Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

commit 27e31ab6e8 (parent 7b99b3efb1), committed by PyTorch MergeBot

Add support for torch.Generator type in TorchScript (#110413)

- Add support for `torch.Generator` type in TorchScript
- Add `generator` args to all `torch.nn.init` functions that call `uniform_` or `normal_`
- Add support for `torch.Generator` in LTC's TorchScript backend (CC: @wconstab)

CC: @eellison @davidberard98 @GlebKazantaev @behzad-a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110413
Approved by: https://github.com/wconstab, https://github.com/albanD, https://github.com/glebk-cerebras, https://github.com/davidberard98
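
In practice, the change makes patterns like the following scriptable; a minimal sketch (not part of the diff) of the newly supported usage:

    import torch

    def make_noise() -> torch.Tensor:
        # A locally seeded generator is now a first-class TorchScript value,
        # so the result is reproducible regardless of the global RNG state.
        g = torch.Generator()
        g.manual_seed(2023)
        t = torch.empty(2, 2)
        t.uniform_(0, 1, generator=g)
        return t

    scripted = torch.jit.script(make_noise)
    assert torch.equal(make_noise(), scripted())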
@@ -2,6 +2,14 @@
 #include <ATen/core/Tensor.h>
 #include <c10/util/Exception.h>

+#include <ATen/CPUGeneratorImpl.h>
+#ifdef USE_CUDA
+#include <ATen/cuda/CUDAGeneratorImpl.h>
+#endif
+#ifdef USE_MPS
+#include <ATen/mps/MPSGeneratorImpl.h>
+#endif
+
 namespace at {

 void Generator::set_state(const at::Tensor& new_state) {
@@ -13,4 +21,32 @@ at::Tensor Generator::get_state() const {
   return at::Tensor::wrap_tensor_impl(this->impl_->get_state());
 }

+Generator make_generator_for_device(c10::Device device, c10::optional<int64_t> seed) {
+  if (device.is_cpu()) {
+    if (seed.has_value()) {
+      return at::detail::createCPUGenerator(seed.value());
+    } else {
+      return at::detail::createCPUGenerator();
+    }
+#ifdef USE_CUDA
+  } else if (device.is_cuda()) {
+    auto generator = at::cuda::detail::createCUDAGenerator(device.index());
+    if (seed.has_value()) {
+      generator.set_current_seed(seed.value());
+    }
+    return generator;
+#endif
+#ifdef USE_MPS
+  } else if (device.is_mps()) {
+    if (seed.has_value()) {
+      return at::mps::detail::createMPSGenerator(seed.value());
+    } else {
+      return at::mps::detail::createMPSGenerator();
+    }
+#endif
+  } else {
+    AT_ERROR("Unsupported device for at::make_generator found: ", device.str());
+  }
+}
+
 } // namespace at
@@ -145,6 +145,9 @@ Generator make_generator(Args&&... args) {
   return Generator(c10::make_intrusive<Impl>(std::forward<Args>(args)...));
 }

+Generator make_generator_for_device(
+    c10::Device device, c10::optional<int64_t> seed = c10::nullopt);
+
 /**
  * Utility function to static cast input Generator* to
  * the backend generator type (CPU/CUDAGeneratorImpl etc.)
@@ -644,6 +644,13 @@ std::ostream& IValue::repr(
       c10::printQuotedString(out, device_stream.str());
       return out << ")";
     }
+    case IValue::Tag::Generator: {
+      auto generator = v.toGenerator();
+      out << "torch.Generator(device=";
+      c10::printQuotedString(out, generator.device().str());
+      out << ", seed=" << generator.current_seed() << ")";
+      return out;
+    }
     case IValue::Tag::GenericDict:
       return printMaybeAnnotatedDict(out, v, formatter);
     case IValue::Tag::Enum: {
@@ -956,6 +963,7 @@ IValue IValue::deepcopy(
     case IValue::Tag::SymBool:
     case IValue::Tag::Bool:
     case IValue::Tag::Device:
+    case IValue::Tag::Generator:
     case IValue::Tag::Uninitialized: {
       copy = *this;
     } break;
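
For reference, the repr emitted by the new Generator case above is the same string the lazy-backend test later asserts on; a brief sketch (assuming a freshly constructed CPU generator):

    import torch

    g = torch.Generator()
    g.manual_seed(1)
    # The C++ repr above would render this IValue as:
    #   torch.Generator(device="cpu", seed=1)
    # (device quoted via c10::printQuotedString, seed from current_seed())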
@@ -28,6 +28,7 @@ namespace c10 {
  _(complex, ComplexType)   \
  _(str, StringType)        \
  _(Device, DeviceObjType)  \
+ _(Generator, GeneratorType) \
  _(Stream, StreamObjType)  \
  _(number, NumberType)     \
  _(None, NoneType)         \
@@ -168,6 +168,9 @@ full_codegen:
   - slice_scatter
   - diagonal_scatter
   - as_strided_scatter
+  # random ops
+  - normal_functional
+  - uniform
 ir_gen:
   - selu
 supported:
@@ -177,7 +180,6 @@ supported:
   - empty.memory_format
   - empty_strided
   - fill_.Scalar
-  - normal_
   - max_pool3d_with_indices
   - max_pool3d_with_indices_backward
   - _to_copy
@@ -446,7 +446,6 @@ lazy_tensor_ts_sources = [
     "torch/csrc/lazy/ts_backend/dynamic_ir.cpp",
     "torch/csrc/lazy/ts_backend/config.cpp",
     "torch/csrc/lazy/ts_backend/ops/device_data.cpp",
-    "torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
     "torch/csrc/lazy/ts_backend/ops/generic.cpp",
     "torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp",
     "torch/csrc/lazy/ts_backend/ts_autograd_functions.cpp",
@@ -88,4 +88,3 @@ we suggest using :meth:`torch.jit.trace`.
 * :class:`torch.nn.AdaptiveLogSoftmaxWithLoss`
 * :class:`torch.autograd.Function`
 * :class:`torch.autograd.enable_grad`
-* :class:`torch.Generator`
test/jit/test_generator.py (new file, 195 lines)
@@ -0,0 +1,195 @@
# Owner(s): ["oncall: jit"]

import io
import math
import unittest

import torch
from torch.nn import init
from torch.testing._internal.common_utils import skipIfLegacyJitExecutor
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestGenerator(JitTestCase):
    # torch.jit.trace does not properly capture the generator manual seed
    # and thus is non-deterministic even if the generator is manually seeded
    @skipIfLegacyJitExecutor("legacy JIT executor does not support Generator type")
    @unittest.expectedFailure
    def test_trace(self):
        def f():
            generator = torch.Generator()
            generator.seed()
            generator.manual_seed(2023)
            generator.initial_seed()
            tensor = torch.empty(2, 2)
            tensor.uniform_(0, 1, generator=generator)
            return tensor

        traced_f = torch.jit.trace(f, ())

        # Run this 3 times to ensure that the generator is being manually seeded
        # each time the traced function is run
        for i in range(3):
            torch.manual_seed(1)

            eager_tensor = f()

            # Change the seed of the default generator to
            # check that we're using the generator from the
            # trace
            torch.manual_seed(2)
            traced_tensor = traced_f()

            self.assertEqual(eager_tensor, traced_tensor)

    def test_script(self):
        def f():
            generator = torch.Generator()
            generator.seed()
            generator.manual_seed(2023)
            generator.initial_seed()
            tensor = torch.empty(2, 2)
            tensor.normal_(-1.0, 1.0, generator=generator)
            return tensor

        script_f = torch.jit.script(f, ())

        # Run this 3 times to ensure that the generator is being manually seeded
        # each time the scripted function is run
        for i in range(3):
            torch.manual_seed(1)

            eager_tensor = f()

            # Change the seed of the default generator to
            # check that we're using the generator from the
            # script
            torch.manual_seed(2)

            script_tensor = script_f()

            self.assertEqual(eager_tensor, script_tensor)

    def test_default_generator(self):
        def f():
            # check that calling manual seed for the default generator works
            torch.manual_seed(2023)
            tensor = torch.empty(2, 2)
            tensor.normal_(-1.0, 1.0)
            return tensor

        torch.manual_seed(1)

        eager_tensor = f()

        torch.manual_seed(2)

        script_f = torch.jit.script(f, ())
        script_tensor = script_f()

        self.assertEqual(eager_tensor, script_tensor)

    def test_generator_arg(self):
        def f(generator: torch.Generator):
            tensor = torch.empty(2, 2)
            tensor.normal_(-1.0, 1.0, generator=generator)
            return tensor

        generator = torch.Generator()
        generator.manual_seed(2023)

        script_f = torch.jit.script(f, (generator,))

        for i in range(3):
            generator = torch.Generator()
            generator.manual_seed(2023 + i)

            torch.manual_seed(1 + i)

            eager_tensor = f(generator)

            generator = torch.Generator()
            generator.manual_seed(2023 + i)

            torch.manual_seed(1 + i)

            script_tensor = script_f(generator)

            self.assertEqual(eager_tensor, script_tensor)

    def test_save_load(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.foo = torch.nn.Linear(2, 2, bias=False)
                self.bar = torch.nn.Linear(2, 2, bias=False)

                self.reset_parameters()

            def reset_linear(self, module, generator):
                init.kaiming_uniform_(
                    module.weight, a=math.sqrt(5), generator=generator
                )

            def reset_parameters(self):
                generator = torch.Generator()
                generator.manual_seed(1)
                self.reset_linear(self.foo, generator)

                generator = torch.Generator()
                generator.manual_seed(2)
                self.reset_linear(self.bar, generator)

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)

                generator = torch.Generator()
                generator.manual_seed(3)
                r = torch.empty_like(x)
                r.normal_(0.0, 1.0, generator=generator)

                return x, r

        eager_foo = Foo()

        script_module = torch.jit.script(Foo())
        saved_module = io.BytesIO()
        torch.jit.save(script_module, saved_module)
        saved_module.seek(0)

        loaded_module = torch.jit.load(saved_module)

        self.assertEqual(eager_foo.foo.weight, loaded_module.foo.weight)
        self.assertEqual(eager_foo.bar.weight, loaded_module.bar.weight)

        try:
            # Run this 3 times to make sure that the generator seed is being set
            # every time forward is called
            for i in range(3):
                x = torch.ones(2, 2)
                out1, r1 = eager_foo(x)
                out2, r2 = loaded_module(x)

                try:
                    self.assertEqual(out1, out2)
                except:  # noqa: B001, E722
                    print(f"Iteration {i}:\n{out1=}\n{out2=}")
                    raise

                try:
                    self.assertEqual(r1, r2)
                except:  # noqa: B001, E722
                    print(f"Iteration {i}:\n{r1=}\n{r2=}")
                    raise
        except:  # noqa: B001, E722
            print(loaded_module.forward.code)
            raise
test/lazy/test_generator.py (new file, 103 lines)
@@ -0,0 +1,103 @@
# Owner(s): ["oncall: jit"]

import torch
import torch._lazy.metrics as metrics
import torch._lazy.ts_backend
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase

torch._lazy.ts_backend.init()


class LazyGeneratorTest(TestCase):
    def test_generator(self):
        """
        Test that generators are being inserted into the TorchScript
        graph: set a different global seed before each call to
        generate_tensor and check that the resulting tensors are the same
        """

        def generate_tensor():
            g1 = torch.Generator()
            g1.manual_seed(2023)
            t1 = torch.tensor(1.0)
            t1.uniform_(generator=g1)

            g2 = torch.Generator()
            g2.manual_seed(2024)
            t2 = torch.tensor(1.0)
            t2.normal_(generator=g2)

            return t1, t2

        torch.manual_seed(1)

        with torch.device("cpu"):
            cpu_t1, cpu_t2 = generate_tensor()

        torch.manual_seed(2)

        with torch.device("lazy"):
            lazy_t1, lazy_t2 = generate_tensor()

        torch._lazy.mark_step()

        assert torch.allclose(
            cpu_t1, lazy_t1.to("cpu")
        ), f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
        assert torch.allclose(
            cpu_t2, lazy_t2.to("cpu")
        ), f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"

    @skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type")
    def test_generator_causes_multiple_compiles(self):
        """
        Test that inserting generators with different seeds causes a recompile
        """

        def generate_tensor(seed):
            t = torch.tensor(1.0)
            g = torch.Generator()
            g.manual_seed(seed)
            t.uniform_(-1, 1, generator=g)
            return t

        metrics.reset()

        with torch.device("lazy"):
            t = generate_tensor(1)
            torch._lazy.mark_step()

            uncached_compile = metrics.counter_value("UncachedCompile")
            assert (
                uncached_compile == 1
            ), f"Expected 1 uncached compile, got {uncached_compile}"

            t = generate_tensor(2)
            torch._lazy.mark_step()

            uncached_compile = metrics.counter_value("UncachedCompile")
            assert (
                uncached_compile == 2
            ), f"Expected 2 uncached compiles, got {uncached_compile}"

            t = generate_tensor(1)
            torch._lazy.mark_step()

            uncached_compile = metrics.counter_value("UncachedCompile")
            assert (
                uncached_compile == 2
            ), f"Expected 2 uncached compiles, got {uncached_compile}"
            cached_compile = metrics.counter_value("CachedCompile")
            assert (
                cached_compile == 1
            ), f"Expected 1 cached compile, got {cached_compile}"

        metrics.reset()

        latest_graph = torch._C._lazy_ts_backend._get_latest_computation_graph()
        assert 'torch.Generator(device="cpu", seed=1)' in latest_graph
        assert "aten::uniform" in latest_graph


if __name__ == "__main__":
    run_tests()
@@ -231,6 +231,9 @@ class TestLazyOpInfo(TestCase):

        samples = op.sample_inputs("lazy", dtype, requires_grad=False)
        for sample in samples:
+            # Need to run mark step so that all random ops are computed in the right order
+            torch._lazy.mark_step()
+
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs
            copy_args = clone_to_device(args, test_device)
@@ -238,6 +241,7 @@ class TestLazyOpInfo(TestCase):
            r_exp = op(*copy_args, **kwargs)
            r_actual = op(*args, **kwargs)

+            torch._lazy.mark_step()
            assert_allclose_rec((r_actual, r_exp))

    @ops([op for op in op_db if op.name in LAZY_OPS_LIST and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST], allowed_dtypes=(torch.float,))  # noqa: B950
@@ -263,6 +267,9 @@ class TestLazyOpInfo(TestCase):

        samples = op.sample_inputs("lazy", dtype, requires_grad=False)
        for sample in samples:
+            # Need to run mark step so that all random ops are computed in the right order
+            torch._lazy.mark_step()
+
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs
            copy_args = clone_to_device(args, test_device)
@@ -75,6 +75,7 @@ from jit.test_dce import TestDCE  # noqa: F401
 from jit.test_sparse import TestSparse  # noqa: F401
 from jit.test_tensor_methods import TestTensorMethods  # noqa: F401
 from jit.test_dataclasses import TestDataclasses  # noqa: F401
+from jit.test_generator import TestGenerator  # noqa: F401

 # Torch
 from torch import Tensor
@@ -14169,6 +14170,41 @@ dedent """

        FileCheck().check_not("prim::PythonOp").run(cu.test.graph)

+    def test_nn_init_generator(self):
+        init_fns = (
+            'uniform_', 'normal_', 'xavier_normal_', 'xavier_uniform_',
+        )
+
+        for name in init_fns:
+            # Build test code
+            code = dedent('''
+                def test(tensor, generator):
+                    # type: (Tensor, Generator)
+                    return torch.nn.init.{name}(tensor, generator=generator)
+            ''').format(name=name)
+            cu = torch.jit.CompilationUnit(code)
+
+            # Compare functions
+            init_fn = getattr(torch.nn.init, name)
+
+            torch.manual_seed(1)
+
+            g = torch.Generator()
+            g.manual_seed(2023)
+            script_out = cu.test(torch.ones(2, 2), g)
+
+            # Change the seed of the default generator to make
+            # sure that we're using the provided generator
+            torch.manual_seed(2)
+
+            g = torch.Generator()
+            g.manual_seed(2023)
+            eager_out = init_fn(torch.ones(2, 2), generator=g)
+
+            self.assertEqual(script_out, eager_out)
+
+            FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
+
    def test_early_return_rewrite(self):
        def test_foo(x: bool):
            if x:
@@ -94,6 +94,7 @@ class TestPublicBindings(TestCase):
            "Future",
            "FutureType",
            "Generator",
+           "GeneratorType",
            "get_autocast_cpu_dtype",
            "get_autocast_ipu_dtype",
            "get_default_dtype",
@@ -1909,6 +1909,10 @@ class DeviceObjType(JitType):
    @staticmethod
    def get() -> DeviceObjType: ...

+class _GeneratorType(JitType):
+    @staticmethod
+    def get() -> _GeneratorType: ...
+
 class StreamObjType(JitType):
    @staticmethod
    def get() -> StreamObjType: ...
@@ -2331,6 +2331,7 @@ def uniform(
    x: Tensor,
    low: Union[bool, int, float] = 0.0,
    high: Union[bool, int, float] = 1.0,
+    generator: Optional[torch.Generator] = None,
 ):
    return prims._uniform_helper(
        x.shape,
@@ -2338,13 +2339,13 @@ def uniform(
        high=sym_float(high),
        dtype=x.dtype,
        device=x.device,
+        generator=generator,
    )


@register_decomposition(aten.uniform_)
def uniform_(self, low=0, high=1, generator=None):
-    assert generator is None
-    return self.copy_(uniform(self, low, high))
+    return self.copy_(uniform(self, low, high, generator))


# aten/src/ATen/native/UpSample.cpp compute_output_size
@@ -2765,8 +2765,6 @@ svd = _make_prim(
 #


-# TODO: add generator support
-# NOTE: there is currently no way of acquiring the "default" torch generator
 def _normal_meta(
     shape: ShapeType,
     *,
@@ -2775,6 +2773,7 @@ def _normal_meta(
     dtype: torch.dtype,
     device: torch.device,
     requires_grad: bool,
+    generator: Optional[torch.Generator] = None,
 ) -> TensorLikeType:
     torch._check(
         std >= 0.0,
@@ -2798,11 +2797,12 @@ def _normal_aten(
     dtype: torch.dtype,
     device: torch.device,
     requires_grad: bool,
+    generator: Optional[torch.Generator] = None,
 ) -> Tensor:
     a = torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
     with torch.no_grad():
         # NOTE: normal_ is incorrectly annotated to expect mean to be a float
-        a.normal_(mean, std)  # type: ignore[arg-type]
+        a.normal_(mean, std, generator=generator)  # type: ignore[arg-type]
     return a


@@ -2815,7 +2815,7 @@ _normal_doc = """

 normal = _make_prim(
     schema=(
-        "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad) -> Tensor"
+        "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad, Generator? generator=None) -> Tensor"  # noqa: B950
     ),
     return_type=RETURN_TYPE.NEW,
     meta=_normal_meta,
@@ -2831,6 +2831,7 @@ def _uniform_meta(
     high: float,
     dtype: torch.dtype,
     device: torch.device,
+    generator: Optional[torch.Generator] = None,
 ) -> TensorLikeType:
     strides = utils.make_contiguous_strides_for(shape)
     return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
@@ -2843,9 +2844,10 @@ def _uniform_aten(
     high: float,
     dtype: torch.dtype,
     device: torch.device,
+    generator: Optional[torch.Generator] = None,
 ) -> Tensor:
     a = torch.empty(shape, dtype=dtype, device=device)
-    a.uniform_(low, high)
+    a.uniform_(low, high, generator=generator)
     return a


@@ -2856,7 +2858,7 @@ _uniform_doc = """
 # TODO: we should more seriously review randomness modeling and prims
 _uniform_helper = _make_prim(
     schema=(
-        "uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device) -> Tensor"
+        "uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device, Generator? generator=None) -> Tensor"
     ),
     return_type=RETURN_TYPE.NEW,
     meta=_uniform_meta,
@@ -5931,7 +5931,6 @@ def normal(
    device=None,
    pin_memory=None,
 ):
-    assert generator is None
    assert layout is None or layout == torch.strided

    if not isinstance(std, TensorLike):
@@ -5968,6 +5967,7 @@ def normal(
        dtype=dtype,
        device=device,
        requires_grad=False,
+        generator=generator,
    )
    return std * normal_samples + mean

@@ -241,6 +241,17 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
     return SpecialFormValue::create(aten::index);
   }

+  if (auto generator_type = value_->type()->cast<GeneratorType>()) {
+    // Handle access to Generator's `manual_seed`, `initial_seed` and `seed`
+    // attributes.
+    if (field == "manual_seed" || field == "initial_seed" || field == "seed") {
+      if (auto builtin = BuiltinFunction::tryCreate(
+              Symbol::aten(field), NamedValue(loc, "self", value_))) {
+        return builtin;
+      }
+    }
+  }
+
   ErrorReport report(loc);
   report << "'" << value_->type()->repr_str()
          << "' object has no attribute or method '" << field << "'.";
@@ -679,12 +679,14 @@ void addInputs(
     Node* n,
     const char* name,
     const c10::optional<at::Generator>& value) {
-  if (value.has_value() && value->defined()) {
-    detail::badArgType(*value);
-  }
   Graph* g = n->owningGraph();
-  Value* undef_gen = g->insertNode(g->createNone())->output();
-  n->addInput(undef_gen);
+
+  if (value.has_value() && value->defined()) {
+    detail::genericAddInput(n, *value);
+  } else {
+    Value* undef_gen = g->insertNode(g->createNone())->output();
+    n->addInput(undef_gen);
+  }
 }
 void addInputs(Node* n, const char* name, at::Device value) {
   detail::genericAddInput(n, value);
@@ -5,6 +5,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/runtime/custom_operator.h>
 #include <torch/csrc/jit/runtime/operator.h>
+#include <torch/csrc/jit/runtime/register_ops_utils.h>

 namespace torch::jit {

@@ -108,6 +109,10 @@ c10::optional<Value*> tryInsertConstant(
     ss << val.toDevice();
     n->s_(attr::value, ss.str());
     n->output()->setType(DeviceObjType::get());
+  } else if (val.isGenerator()) {
+    auto generator = val.toGenerator();
+    n->ival_(attr::value, generator);
+    n->output()->setType(GeneratorType::get());
   } else if (val.isStream()) {
     // packing into int64_t removed
     n->ival_(attr::value, val);
@@ -194,6 +199,9 @@ c10::optional<IValue> toIValue(const Value* v) {
   } else if (type == DeviceObjType::get()) {
     auto d = c10::Device(node->s(attr::value));
     return d;
+  } else if (type == GeneratorType::get()) {
+    auto generator = node->ival(attr::value).toGenerator();
+    return generator;
   } else if (type == StreamObjType::get()) {
     // int64_t packing removed
     auto s = node->ival(attr::value).toStream();
@@ -142,6 +142,9 @@ bool ivaluesEqual(const IValue& a1, const IValue& a2) {
   if (a1.isObject()) {
     return &a1.toObjectRef() == &a2.toObjectRef();
   }
+  if (a1.isGenerator()) {
+    return a1.toGenerator() == a2.toGenerator();
+  }
   TORCH_INTERNAL_ASSERT(false);
 }

@@ -397,6 +397,8 @@ inline InferredType tryToInferType(py::handle input) {
     return InferredType(IntType::get());
   } else if (THPDevice_Check(input.ptr())) {
     return InferredType(DeviceObjType::get());
+  } else if (THPGenerator_Check(input.ptr())) {
+    return InferredType(GeneratorType::get());
   } else if (THPStream_Check(input.ptr())) {
     return InferredType(StreamObjType::get());
   } else if (THPDtype_Check(input.ptr())) {
@@ -1012,6 +1012,10 @@ void initPythonIRBindings(PyObject* module_) {
       .def_static("get", &StringType::get);
   py::class_<DeviceObjType, Type, DeviceObjTypePtr>(m, "DeviceObjType")
       .def_static("get", &DeviceObjType::get);
+  // TODO(antoniojkim): Add GeneratorType to the public API once it's been added
+  //                    to the public documentation
+  py::class_<GeneratorType, Type, GeneratorTypePtr>(m, "_GeneratorType")
+      .def_static("get", &GeneratorType::get);
   py::class_<StreamObjType, Type, StreamObjTypePtr>(m, "StreamObjType")
       .def_static("get", &StreamObjType::get);
   py::class_<PyObjectType, Type, PyObjectTypePtr>(m, "PyObjectType")
@@ -1,4 +1,5 @@
 #include <ATen/autocast_mode.h>
+#include <ATen/core/Generator.h>
 #include <c10/util/Optional.h>
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/mobile/promoted_prim_ops.h>
@@ -2492,6 +2493,44 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs1{
         TORCH_SELECTIVE_SCHEMA("aten::manual_seed(int seed) -> ()"),
         [](Stack& stack) { at::manual_seed(pop(stack).toInt()); },
         aliasAnalysisFromSchema()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA(
+            "aten::Generator(*, Device? device=None, int? seed=None) -> Generator"),
+        [](Stack& stack) {
+          auto seed = pop(stack).toOptional<int64_t>();
+          auto device = pop(stack).toOptional<c10::Device>();
+          push(
+              stack,
+              at::make_generator_for_device(
+                  device.value_or(c10::Device("cpu")), seed));
+        },
+        aliasAnalysisFromSchema()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA("aten::initial_seed(Generator self) -> int"),
+        [](Stack& stack) {
+          auto generator = pop(stack);
+          auto current_seed = generator.toGenerator().current_seed();
+          push(stack, (int64_t)current_seed);
+        },
+        aliasAnalysisFromSchema()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA(
+            "aten::manual_seed.generator(Generator(a!) self, int seed) -> Generator(a!)"),
+        [](Stack& stack) {
+          auto seed = pop(stack).toInt();
+          auto generator = pop(stack);
+          generator.toGenerator().set_current_seed(seed);
+          push(stack, generator);
+        },
+        aliasAnalysisFromSchema()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA("aten::seed(Generator(a!) self) -> int"),
+        [](Stack& stack) {
+          auto generator = pop(stack);
+          auto current_seed = generator.toGenerator().seed();
+          push(stack, (int64_t)current_seed);
+        },
+        aliasAnalysisFromSchema()),
     OperatorGeneratorArgs(
         TORCH_SELECTIVE_SCHEMA("aten::cuda(Tensor(a) self) -> Tensor(a|b)"),
         [](Stack& stack) {
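
The four operators above are what generator method calls resolve to in scripted code. A minimal sketch of the resulting surface behavior (assuming eager/script parity, which the tests in this PR exercise):

    import torch

    @torch.jit.script
    def reseed() -> int:
        g = torch.Generator()    # lowers to aten::Generator(*, Device?, int?)
        g.seed()                 # aten::seed - nondeterministically reseeds
        g.manual_seed(42)        # aten::manual_seed.generator
        return g.initial_seed()  # aten::initial_seed - reads the current seed

    assert reseed() == 42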
@@ -392,7 +392,7 @@ RegisterOperators reg({
         aliasAnalysisFromSchema()),
     OperatorGenerator(
         TORCH_SELECTIVE_SCHEMA(
-            "aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b) -> Tensor(a!)"),
+            "aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b, Generator? generator=None) -> Tensor(a!)"),
         [](Stack& stack) {
           // TODO: remove when script supports setting grad mode
           torch::NoGradGuard no_grad;
@@ -402,13 +402,16 @@ RegisterOperators reg({
           double a;
           // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
           double b;
+          c10::optional<at::Generator> generator =
+              pop(stack).toOptional<at::Generator>();
+
           pop(stack, tensor, a, b);
-          push(stack, tensor.uniform_(a, b));
+          push(stack, tensor.uniform_(a, b, generator));
         },
         aliasAnalysisFromSchema()),
     OperatorGenerator(
         TORCH_SELECTIVE_SCHEMA(
-            "aten::_no_grad_normal_(Tensor(a!) tensor, float mean, float std) -> Tensor(a!)"),
+            "aten::_no_grad_normal_(Tensor(a!) tensor, float mean, float std, Generator? generator=None) -> Tensor(a!)"),
         [](Stack& stack) {
           // TODO: remove when script supports setting grad mode
           torch::NoGradGuard no_grad;
@@ -418,8 +421,11 @@ RegisterOperators reg({
           double mean;
           // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
           double std;
+          c10::optional<at::Generator> generator =
+              pop(stack).toOptional<at::Generator>();
+
           pop(stack, tensor, mean, std);
-          push(stack, tensor.normal_(mean, std));
+          push(stack, tensor.normal_(mean, std, generator));
         },
         aliasAnalysisFromSchema()),
     OperatorGenerator(
@@ -148,6 +148,11 @@ static inline hash_t Hash(const std::string& value) {
 static inline hash_t Hash(const c10::string_view& value) {
   return DataHash(value.data(), value.size());
 }
+
+static inline hash_t Hash(const at::Generator& value) {
+  return TensorHash(value.get_state());
+}
+
 // Taken from glibc's implementation of hashing optionals,
 // we want to include a contribution to the hash to distinguish
 // cases where one or another option was null, but we hope it doesn't
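
Hashing a generator by its state tensor is what lets the lazy backend distinguish differently seeded generators (and cache identically seeded ones), as the lazy test in this PR verifies via compile counters. A small eager-mode sketch of the invariant this relies on:

    import torch

    g1, g2 = torch.Generator(), torch.Generator()
    g1.manual_seed(1)
    g2.manual_seed(1)
    # get_state() is the tensor that TensorHash consumes above
    assert torch.equal(g1.get_state(), g2.get_state())

    g2.manual_seed(2)  # different seed -> different state -> different hash
    assert not torch.equal(g1.get_state(), g2.get_state())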
@@ -1369,6 +1369,22 @@ std::vector<Shape> compute_shape_as_strided_scatter_symint(
   return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
 }

+std::vector<Shape> compute_shape_normal_functional(
+    const at::Tensor& self,
+    double mean,
+    double std,
+    c10::optional<at::Generator> generator) {
+  return {Shape(self.scalar_type(), self.sizes().vec())};
+}
+
+std::vector<Shape> compute_shape_uniform(
+    const at::Tensor& self,
+    double from,
+    double to,
+    c10::optional<at::Generator> generator) {
+  return {Shape(self.scalar_type(), self.sizes().vec())};
+}
+
 // Restore unused-parameters warnings
 #pragma GCC diagnostic pop

@@ -70,6 +70,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_new_empty_strided(const
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_nonzero(const at::Tensor & self);
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_normal_functional(const at::Tensor & self, double mean, double std, c10::optional<at::Generator> generator);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, c10::optional<at::Generator> generator);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, int64_t to, c10::optional<at::Generator> generator);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator);
@@ -91,6 +92,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_narrow_copy_symint(const
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_hardswish(const at::Tensor & self);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_hardswish_backward(const at::Tensor & grad_output, const at::Tensor & self);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_selu(const at::Tensor & self);
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_uniform(const at::Tensor & self, double from, double to, c10::optional<at::Generator> generator);

 // Non-Native ops
 TORCH_API std::vector<Shape> compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type);
@@ -307,6 +307,15 @@ void initLazyBindings(PyObject* module) {
 #endif  // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
     return result;
   });
+  lazy_ts_backend.def("_get_latest_computation_graph", []() {
+    auto computation = LazyGraphExecutor::Get()
+                           ->GetComputationCache()
+                           ->GetLatest()
+                           ->computation;
+    auto ts_computation = dynamic_cast<TSComputation*>(computation.get());
+    TORCH_CHECK(ts_computation, "Found non-TSComputation in cache");
+    return ts_computation->graph()->toString();
+  });

   // GetPythonFramesFunction() has not ever worked with torchdeploy/multipy
   // possibly because GetPythonFrames resolves to external cpython rather
@@ -1,47 +0,0 @@ torch/csrc/lazy/ts_backend/ops/random_ops.cpp (file deleted)
#include <torch/csrc/lazy/core/util.h>
#include <torch/csrc/lazy/ts_backend/ops/random_ops.h>

namespace torch {
namespace lazy {

Normal::Normal(
    const torch::lazy::Value& self,
    const double& mean,
    const double& std,
    std::vector<torch::lazy::Shape>&& shapes)
    : torch::lazy::TsNode(
          ClassOpKind(),
          {self},
          std::move(shapes),
          /* num_outputs */ 1,
          torch::lazy::MHash(mean, std)),
      mean_(mean),
      std_(std) {}

std::string Normal::ToString() const {
  std::stringstream ss;
  ss << TsNode::ToString();
  ss << ", mean=" << mean_;
  ss << ", std=" << std_;
  return ss.str();
}

torch::lazy::TSOpVector Normal::Lower(
    std::shared_ptr<torch::jit::GraphFunction> function,
    torch::lazy::TSLoweringContext* loctx) const {
  std::vector<torch::jit::NamedValue> arguments;
  std::vector<torch::jit::NamedValue> kwarguments;
  arguments.reserve(3);
  size_t i = 0;
  arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
  arguments.emplace_back("mean", mean_);
  arguments.emplace_back("std", std_);
  torch::lazy::TSOpVector normal__out =
      torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
  TORCH_CHECK_EQ(normal__out.size(), 1);

  return normal__out;
}

} // namespace lazy
} // namespace torch
@@ -1,30 +0,0 @@ torch/csrc/lazy/ts_backend/ops/random_ops.h (file deleted)
#pragma once

#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

class Normal : public torch::lazy::TsNode {
 public:
  static OpKind ClassOpKind() {
    return OpKind::Get("aten::normal_");
  }

  Normal(
      const torch::lazy::Value& self,
      const double& mean,
      const double& std,
      std::vector<torch::lazy::Shape>&& shapes);

  std::string ToString() const override;
  torch::lazy::TSOpVector Lower(
      std::shared_ptr<torch::jit::GraphFunction> function,
      torch::lazy::TSLoweringContext* loctx) const override;

  double mean_;
  double std_;
};

} // namespace lazy
} // namespace torch
@@ -13,7 +13,6 @@
 #include <torch/csrc/lazy/core/tensor.h>
 #include <torch/csrc/lazy/core/util.h>
 #include <torch/csrc/lazy/generated/LazyIr.h>
-#include <torch/csrc/lazy/ts_backend/ops/random_ops.h>
 #include <algorithm>
 #include <functional>

@@ -14,7 +14,6 @@
 #include <torch/csrc/lazy/core/tensor_util.h>
 #include <torch/csrc/lazy/generated/LazyNativeFunctions.h>
 #include <torch/csrc/lazy/ts_backend/config.h>
-#include <torch/csrc/lazy/ts_backend/ops/random_ops.h>
 #include <torch/csrc/lazy/ts_backend/ops/to_copy.h>
 #include <torch/csrc/lazy/ts_backend/tensor_aten_ops.h>
 #include <torch/csrc/lazy/ts_backend/ts_autograd_functions.h>
@@ -372,36 +371,6 @@ at::Tensor LazyNativeFunctions::max_pool3d_with_indices_backward(
       indices);
 }

-at::Tensor& LazyNativeFunctions::normal_(
-    at::Tensor& self,
-    double mean,
-    double std,
-    c10::optional<at::Generator> generator) {
-  // Unconditionally fall back.
-  // implementing normal_ via lazy tensor caused differences in results compared
-  // to eager.
-  return at::native::call_fallback_fn<&ltc_eager_fallback, ATEN_OP(normal_)>::
-      call(self, mean, std, generator);
-
-  // if (force_eager_fallback(c10::Symbol::fromQualString("aten::normal_"))) {
-  //   return at::native::call_fallback_fn<&ltc_eager_fallback,
-  //   ATEN_OP(normal_)>::call(self, mean, std, generator);
-  // }
-
-  // if (generator.has_value()) {
-  //   return at::native::call_fallback_fn<&ltc_eager_fallback,
-  //   ATEN_OP(normal_)>::call(self, mean, std, generator);
-  // }
-
-  // TORCH_LAZY_FN_COUNTER("lazy::");
-  // auto device = bridge::GetBackendDevice(self);
-  // LazyTensor lazy_self = GetLtcTensorOrCreateForWrappedNumber(self, *device);
-  // std::vector<torch::lazy::Shape> shapes =
-  //     {torch::lazy::Shape(self.scalar_type(), self.sizes().vec())}; auto node =
-  //     torch::lazy::MakeNode<Normal>(lazy_self.GetIrValue(), mean, std,
-  //     std::move(shapes)); lazy_self.SetInPlaceIrValue(node); return self;
-};
-
 at::Tensor LazyNativeFunctions::_unsafe_view(
     const at::Tensor& self,
     at::IntArrayRef size) {
@@ -13,6 +13,7 @@ from typing import Type
 import torch

 from torch._C import (
+    _GeneratorType,
     AnyType,
     AwaitType,
     BoolType,
@@ -479,6 +480,8 @@ def try_ann_to_type(ann, loc, rcb=None):
         return InterfaceType(ann.__torch_script_interface__)
     if ann is torch.device:
         return DeviceObjType.get()
+    if ann is torch.Generator:
+        return _GeneratorType.get()
     if ann is torch.Stream:
         return StreamObjType.get()
     if ann is torch.dtype:
@@ -10,14 +10,14 @@ from typing import Optional as _Optional
 # functions that use `with torch.no_grad()`. The JIT doesn't support context
 # managers, so these need to be implemented as builtins. Using these wrappers
 # lets us keep those builtins small and re-usable.
-def _no_grad_uniform_(tensor, a, b):
+def _no_grad_uniform_(tensor, a, b, generator=None):
     with torch.no_grad():
-        return tensor.uniform_(a, b)
+        return tensor.uniform_(a, b, generator=generator)


-def _no_grad_normal_(tensor, mean, std):
+def _no_grad_normal_(tensor, mean, std, generator=None):
     with torch.no_grad():
-        return tensor.normal_(mean, std)
+        return tensor.normal_(mean, std, generator=generator)


 def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
@@ -121,7 +121,12 @@ def calculate_gain(nonlinearity, param=None):
     raise ValueError(f"Unsupported nonlinearity {nonlinearity}")


-def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
+def uniform_(
+    tensor: Tensor,
+    a: float = 0.0,
+    b: float = 1.0,
+    generator: _Optional[torch.Generator] = None,
+) -> Tensor:
     r"""Fill the input Tensor with values drawn from the uniform distribution.

     :math:`\mathcal{U}(a, b)`.
@@ -130,17 +135,25 @@ def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
         tensor: an n-dimensional `torch.Tensor`
         a: the lower bound of the uniform distribution
         b: the upper bound of the uniform distribution
+        generator: the torch Generator to sample from (default: None)

     Examples:
         >>> w = torch.empty(3, 5)
         >>> nn.init.uniform_(w)
     """
     if torch.overrides.has_torch_function_variadic(tensor):
-        return torch.overrides.handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
-    return _no_grad_uniform_(tensor, a, b)
+        return torch.overrides.handle_torch_function(
+            uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
+        )
+    return _no_grad_uniform_(tensor, a, b, generator)

-def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
+def normal_(
+    tensor: Tensor,
+    mean: float = 0.0,
+    std: float = 1.0,
+    generator: _Optional[torch.Generator] = None,
+) -> Tensor:
     r"""Fill the input Tensor with values drawn from the normal distribution.

     :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
@@ -149,14 +162,17 @@ def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
         tensor: an n-dimensional `torch.Tensor`
         mean: the mean of the normal distribution
         std: the standard deviation of the normal distribution
+        generator: the torch Generator to sample from (default: None)

     Examples:
         >>> w = torch.empty(3, 5)
         >>> nn.init.normal_(w)
     """
     if torch.overrides.has_torch_function_variadic(tensor):
-        return torch.overrides.handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
-    return _no_grad_normal_(tensor, mean, std)
+        return torch.overrides.handle_torch_function(
+            normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
+        )
+    return _no_grad_normal_(tensor, mean, std, generator)

 def trunc_normal_(
     tensor: Tensor,
@@ -180,6 +196,7 @@ def trunc_normal_(
         std: the standard deviation of the normal distribution
         a: the minimum cutoff value
         b: the maximum cutoff value
+        generator: the torch Generator to sample from (default: None)

     Examples:
         >>> w = torch.empty(3, 5)
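
With the new argument, initialization becomes reproducible without touching the global RNG. A small sketch (not part of the diff):

    import torch
    from torch import nn

    w1 = torch.empty(3, 5)
    w2 = torch.empty(3, 5)

    g = torch.Generator().manual_seed(0)  # manual_seed returns the generator
    nn.init.uniform_(w1, generator=g)

    g = torch.Generator().manual_seed(0)
    nn.init.uniform_(w2, generator=g)

    assert torch.equal(w1, w2)  # same seed, same init; global RNG untouched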
@@ -314,7 +331,9 @@ def _calculate_fan_in_and_fan_out(tensor):
     return fan_in, fan_out


-def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
+def xavier_uniform_(
+    tensor: Tensor, gain: float = 1.0, generator: _Optional[torch.Generator] = None
+) -> Tensor:
     r"""Fill the input `Tensor` with values using a Xavier uniform distribution.

     The method is described in `Understanding the difficulty of training
@@ -330,6 +349,7 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
     Args:
         tensor: an n-dimensional `torch.Tensor`
         gain: an optional scaling factor
+        generator: the torch Generator to sample from (default: None)

     Examples:
         >>> w = torch.empty(3, 5)
@@ -339,10 +359,14 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
     std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
     a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

-    return _no_grad_uniform_(tensor, -a, a)
+    return _no_grad_uniform_(tensor, -a, a, generator)


-def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor:
+def xavier_normal_(
+    tensor: Tensor,
+    gain: float = 1.0,
+    generator: _Optional[torch.Generator] = None,
+) -> Tensor:
     r"""Fill the input `Tensor` with values using a Xavier normal distribution.

     The method is described in `Understanding the difficulty of training deep feedforward
@@ -357,6 +381,7 @@ def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor:
     Args:
         tensor: an n-dimensional `torch.Tensor`
         gain: an optional scaling factor
+        generator: the torch Generator to sample from (default: None)

     Examples:
         >>> w = torch.empty(3, 5)
@@ -365,7 +390,7 @@ def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor:
     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
     std = gain * math.sqrt(2.0 / float(fan_in + fan_out))

-    return _no_grad_normal_(tensor, 0., std)
+    return _no_grad_normal_(tensor, 0., std, generator)


 def _calculate_correct_fan(tensor, mode):
@@ -379,7 +404,11 @@ def _calculate_correct_fan(tensor, mode):


 def kaiming_uniform_(
-    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
+    tensor: Tensor,
+    a: float = 0,
+    mode: str = "fan_in",
+    nonlinearity: str = "leaky_relu",
+    generator: _Optional[torch.Generator] = None,
 ):
     r"""Fill the input `Tensor` with values using a Kaiming uniform distribution.

@@ -403,6 +432,7 @@ def kaiming_uniform_(
             backwards pass.
         nonlinearity: the non-linear function (`nn.functional` name),
             recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
+        generator: the torch Generator to sample from (default: None)

     Examples:
         >>> w = torch.empty(3, 5)
@@ -415,7 +445,8 @@ def kaiming_uniform_(
             tensor=tensor,
             a=a,
             mode=mode,
-            nonlinearity=nonlinearity)
+            nonlinearity=nonlinearity,
+            generator=generator)

     if 0 in tensor.shape:
         warnings.warn("Initializing zero-element tensors is a no-op")
@@ -425,11 +456,15 @@ def kaiming_uniform_(
     std = gain / math.sqrt(fan)
     bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
     with torch.no_grad():
-        return tensor.uniform_(-bound, bound)
+        return tensor.uniform_(-bound, bound, generator=generator)


 def kaiming_normal_(
-    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
+    tensor: Tensor,
+    a: float = 0,
+    mode: str = "fan_in",
+    nonlinearity: str = "leaky_relu",
+    generator: _Optional[torch.Generator] = None,
 ):
     r"""Fill the input `Tensor` with values using a Kaiming normal distribution.

@@ -453,6 +488,7 @@ def kaiming_normal_(
             backwards pass.
         nonlinearity: the non-linear function (`nn.functional` name),
             recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
+        generator: the torch Generator to sample from (default: None)

     Examples:
         >>> w = torch.empty(3, 5)
@@ -465,10 +501,14 @@ def kaiming_normal_(
     gain = calculate_gain(nonlinearity, a)
     std = gain / math.sqrt(fan)
     with torch.no_grad():
-        return tensor.normal_(0, std)
+        return tensor.normal_(0, std, generator=generator)

-def orthogonal_(tensor, gain=1):
+def orthogonal_(
+    tensor,
+    gain=1,
+    generator: _Optional[torch.Generator] = None,
+):
     r"""Fill the input `Tensor` with a (semi) orthogonal matrix.

     Described in `Exact solutions to the nonlinear dynamics of learning in deep
@@ -479,6 +519,7 @@ def orthogonal_(tensor, gain=1):
     Args:
         tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
         gain: optional scaling factor
+        generator: the torch Generator to sample from (default: None)

     Examples:
         >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
@@ -493,7 +534,7 @@ def orthogonal_(tensor, gain=1):
         return tensor
     rows = tensor.size(0)
     cols = tensor.numel() // rows
-    flattened = tensor.new(rows, cols).normal_(0, 1)
+    flattened = tensor.new(rows, cols).normal_(0, 1, generator=generator)

     if rows < cols:
         flattened.t_()
@@ -514,7 +555,12 @@ def orthogonal_(tensor, gain=1):
     return tensor


-def sparse_(tensor, sparsity, std=0.01):
+def sparse_(
+    tensor,
+    sparsity,
+    std=0.01,
+    generator: _Optional[torch.Generator] = None,
+):
     r"""Fill the 2D input `Tensor` as a sparse matrix.

     The non-zero elements will be drawn from the normal distribution
@@ -526,6 +572,7 @@ def sparse_(tensor, sparsity, std=0.01):
         sparsity: The fraction of elements in each column to be set to zero
         std: the standard deviation of the normal distribution used to generate
             the non-zero values
+        generator: the torch Generator to sample from (default: None)

     Examples:
         >>> w = torch.empty(3, 5)
@@ -538,7 +585,7 @@ def sparse_(tensor, sparsity, std=0.01):
     num_zeros = int(math.ceil(sparsity * rows))

     with torch.no_grad():
-        tensor.normal_(0, std)
+        tensor.normal_(0, std, generator=generator)
         for col_idx in range(cols):
             row_indices = torch.randperm(rows)
             zero_indices = row_indices[:num_zeros]
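
The same pattern composes with module initialization, mirroring the reset_parameters idiom in the test_save_load test above; a short sketch:

    import math
    import torch
    from torch import nn

    layer = nn.Linear(2, 2, bias=False)
    g = torch.Generator()
    g.manual_seed(1)
    # Deterministic Kaiming init driven by an explicit generator
    nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5), generator=g)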
@@ -933,10 +933,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
                                                  distance_function=None, margin=1.0,
                                                  swap=False, reduction='mean': -1),
     torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
-    torch.nn.init.uniform_: lambda tensor, a=0., b=1.: -1,
-    torch.nn.init.normal_: lambda tensor, mean=0., std=1.: -1,
+    torch.nn.init.uniform_: lambda tensor, a=0., b=1., generator=None: -1,
+    torch.nn.init.normal_: lambda tensor, mean=0., std=1., generator=None: -1,
     torch.nn.init.constant_: lambda tensor, val: -1,
-    torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu': -1,
+    torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None: -1,
     torch.nonzero: lambda input, as_tuple=False: -1,
     torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
     torch.argwhere: lambda input: -1,
@@ -9322,7 +9322,14 @@ def wrapper_set_seed(op, *args, **kwargs):
     """
     with freeze_rng_state():
         torch.manual_seed(42)
-        return op(*args, **kwargs)
+        output = op(*args, **kwargs)
+
+        if isinstance(output, torch.Tensor) and output.device.type == "lazy":
+            # We need to call mark step inside freeze_rng_state so that numerics
+            # match eager execution
+            torch._lazy.mark_step()
+
+        return output


 def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5):
@@ -17840,8 +17847,14 @@ op_db: List[OpInfo] = [
            DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
            # Lazy tensor failures
            DecorateInfo(unittest.skip('Skipped!'), 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
-           DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness'),
-           DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'),
+           # These tests fail only when built with ASAN
+           DecorateInfo(unittest.skip("Fails with ASAN"), 'TestLazyOpInfo', 'test_correctness', active_if=TEST_WITH_ASAN),
+           DecorateInfo(
+               unittest.skip("Fails with ASAN"),
+               'TestLazyOpInfo',
+               'test_correctness_with_reusing_ir',
+               active_if=TEST_WITH_ASAN
+           ),
        ),
    ),
    OpInfo(
@@ -1313,6 +1313,28 @@ def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
 def skipRocmIfTorchInductor(msg="test doesn't currently work with torchinductor on the ROCm stack"):
     return skipIfTorchInductor(msg=msg, condition=TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)

+def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT executor"):
+    def decorator(fn):
+        if not isinstance(fn, type):
+            @wraps(fn)
+            def wrapper(*args, **kwargs):
+                if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
+                    raise unittest.SkipTest(msg)
+                else:
+                    fn(*args, **kwargs)
+            return wrapper
+
+        assert(isinstance(fn, type))
+        if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = msg
+
+        return fn
+
+
+    return decorator
+
+
 # Run PyTorch tests with translation validation on.
 TEST_WITH_TV = os.getenv('PYTORCH_TEST_WITH_TV') == '1'

@@ -7,6 +7,7 @@ from torchgen.api.types import (
     CType,
     deviceT,
     doubleT,
+    generatorT,
     layoutT,
     ListCType,
     longT,
@@ -109,6 +110,8 @@ def process_ir_type(
         return BaseCType(stringT)
     elif typ.name == BaseTy.Device:
         return BaseCType(deviceT)
+    elif typ.name == BaseTy.Generator:
+        return BaseCType(generatorT)
     elif typ.name == BaseTy.Layout:
         return BaseCType(layoutT)
     elif typ.name == BaseTy.MemoryFormat:
@@ -218,16 +221,7 @@ class LazyArgument:
         self.symint = symint
         self.is_optional = isinstance(arg.type, OptionalType)
         self.is_generator = isGeneratorType(arg.type)
-        if self.is_generator:
-            assert (
-                self.is_optional
-            ), "We expect all generators are optional since currently they are"
-            # there is no handling for generators in TorchScript IR (or XLA)
-            # so we fall back to eager if the (optional) generator has a value, and otherwise
-            # it's null and safe to exclude from lazy IR
-            self.lazy_type_ = None
-        else:
-            self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
+        self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
         self.is_wrapped_scalar = isWrappedScalarType(arg.type)
         self.is_symint_or_list = symint and (
             isSymIntType(arg.type)
@@ -236,9 +230,7 @@ class LazyArgument:
             # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
         )

-        self.is_lazy_value = not self.is_generator and isValueType(
-            self.lazy_type, properties
-        )
+        self.is_lazy_value = isValueType(self.lazy_type, properties)

     @property
     def lazy_type(self) -> CType:
@@ -419,7 +411,7 @@ class LazyIrSchema:
         keyword: bool = True,
         values: bool = True,
         scalars: bool = True,
-        generator: bool = False,
+        generator: bool = True,
     ) -> List[LazyArgument]:
         # This function maintains the sorted order of arguments but provides different filtered views.
         # Some parts of the code care about kwargs vs args (TS lowerings),
@@ -122,12 +122,8 @@ def gen_fallback_code(
         aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
     else:
         aten_op_str = f"ATEN_OP({schema.aten_name})"
-    or_has_generator = ""
-    if schema.generator_arg:
-        # generators are always optional and there is never more than one, at least currently
-        or_has_generator = f" || ({schema.generator_arg.name}.has_value() && {schema.generator_arg.name}->defined())"
     return f"""
-        if (force_eager_fallback({aten_symbol(schema)}){or_has_generator}) {{
+        if (force_eager_fallback({aten_symbol(schema)})) {{
          return at::native::call_fallback_fn_symint<&ltc_eager_fallback, {aten_op_str}>::call(
              {fallback_args}
          );
@@ -290,9 +286,12 @@ class GenLazyIR(ABC):
         members_to_string = []
         for arg in scalar_args:
             if isinstance(arg.lazy_type, OptionalCType):
+                value = f"{arg.name}.value()"
+                if arg.is_generator:
+                    value = '"torch.Generator()"'
                 members_to_string.append(
                     f"""if ({arg.name}.has_value()) {{
-    ss << ", {arg.name}=" << {arg.name}.value();
+    ss << ", {arg.name}=" << {value};
  }} else {{
    ss << ", {arg.name}=null";
  }}"""