Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-26 00:24:53 +08:00).
Comparing commits: v2.4.0-rc8...tensordict (24 commits).
Commits (SHA1 only; author and date metadata were not captured):
ba96cfd2f7, 7af4f58380, 4352ebe8bc, d52fcc6d33, cd0e9c4c05, 8fc1309f9f, 11a52e8d6d, 52061a05a4, afe4a40805, f5c809e33d, 1d9582c627, 2331b048af, fb2103bee1, 2a69d65d52, f36bd08109, d6160943b1, d147161883, 6d3b90b64b, d040d35294, cf6704548a, b8ab16bad6, 570605e37d, d307e5e0be, 19f3d13102
test/test_dict.py (new file, 4473 lines): diff suppressed because it is too large.
test/test_tensorclass.py (new file, 1637 lines): diff suppressed because it is too large.
@@ -1262,7 +1262,7 @@ def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False
                )
            output, aux = output

        if not isinstance(output, torch.Tensor):
        if not isinstance(output, (torch.Tensor, torch.dict.TensorDictBase)):
            raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                               f'to return a Tensor, got {type(output)}')
        if output.dim() != 0:
@@ -5,12 +5,15 @@ import torch
import torch.nn as nn
from torch import Tensor
from torch._functorch.utils import exposed_in
from torch.dict import TensorDictBase


@exposed_in("torch.func")
def functional_call(
    module: "torch.nn.Module",
    parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
    parameter_and_buffer_dicts: Union[
        Dict[str, Tensor], Sequence[Dict[str, Tensor]], TensorDictBase
    ],
    args: Union[Any, Tuple],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
@@ -118,7 +121,7 @@ def functional_call(
    Returns:
        Any: the result of calling ``module``.
    """
    if isinstance(parameter_and_buffer_dicts, dict):
    if isinstance(parameter_and_buffer_dicts, (dict, TensorDictBase)):
        parameters_and_buffers = parameter_and_buffer_dicts
    elif isinstance(parameter_and_buffer_dicts, Sequence):
        if not all(isinstance(d, dict) for d in parameter_and_buffer_dicts):
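For context, a minimal sketch of what the widened isinstance check enables: parameters can be passed to functional_call as a single TensorDict rather than a plain dict. This example is not part of the diff and assumes torch.dict.TensorDict keeps the (source, batch_size) constructor used by the external tensordict library.

import torch
import torch.nn as nn
from torch.dict import TensorDict  # assumed import path on this branch

module = nn.Linear(3, 4)
# Collect the module's parameters into a TensorDict instead of a plain dict.
params = TensorDict(dict(module.named_parameters()), batch_size=[])
x = torch.randn(2, 3)
# With the (dict, TensorDictBase) check above, a tensordict is accepted directly.
out = torch.func.functional_call(module, params, (x,))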
@@ -9,12 +9,13 @@ import functools
import threading
from torch import Tensor
from typing import Any, Callable, Optional, Tuple, Union, List

from torch.utils._pytree import (
    tree_flatten,
    tree_unflatten,
    tree_map_,
    _broadcast_to_and_flatten,
    TreeSpec,
    TreeSpec, SUPPORTED_NODES,
)
from functools import partial
import os
@@ -28,9 +29,24 @@ from torch._C._functorch import (
    is_batchedtensor,
)


in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]

class _exclude_td_from_pytree:
    def __init__(self):
        from torch.dict._pytree import PYTREE_REGISTERED_TDS

        self.PYTREE_REGISTERED_TDS = PYTREE_REGISTERED_TDS
        self.tdnodes = {}

    def __enter__(self):
        for tdtype in self.PYTREE_REGISTERED_TDS:
            self.tdnodes[tdtype] = SUPPORTED_NODES.pop(tdtype)

    def __exit__(self, exc_type, exc_val, exc_tb):
        for tdtype in self.PYTREE_REGISTERED_TDS:
            SUPPORTED_NODES[tdtype] = self.tdnodes[tdtype]

def doesnt_support_saved_tensors_hooks(f):
    message = (
@@ -80,6 +96,7 @@ def _as_tuple(value: Any, num_elements: int, error_message_lambda: Callable[[],
def _process_batched_inputs(
    in_dims: in_dims_t, args: Tuple, func: Callable
) -> Tuple[int, List[Any], List[Any], TreeSpec]:

    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
@@ -91,14 +108,17 @@ def _process_batched_inputs(
            f'inputs, or you are trying to vmap over a function with no inputs. '
            f'The latter is unsupported.')

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'in_dims is not compatible with the structure of `inputs`. '
            f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
            f'has structure {args_spec}.')
    # we want to escape TensorDicts as they take care of adding the batch dimension
    with _exclude_td_from_pytree():
        flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'in_dims is not compatible with the structure of `inputs`. '
            f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
            f'has structure {args_spec}.'
        )

    for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)):
        if not isinstance(in_dim, int) and in_dim is not None:
@@ -106,7 +126,10 @@ def _process_batched_inputs(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but in_dim must be either '
                f'an integer dimension or None.')
        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
        from torch.dict.base import TensorDictBase
        if isinstance(in_dim, int) and not isinstance(
            arg, (Tensor, TensorDictBase)
        ):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but the input is of type '
@@ -130,10 +153,29 @@ def _process_batched_inputs(
def _create_batched_inputs(
    flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec) -> Tuple:
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    batched_inputs = [arg if in_dim is None else
                      _add_batch_dim(arg, in_dim, vmap_level)
                      for in_dim, arg in zip(flat_in_dims, flat_args)]
    return tree_unflatten(batched_inputs, args_spec)
    # If tensordict, we remove the dim at batch_size[in_dim] such that the TensorDict can accept
    # the batched tensors. This will be added in _unwrap_batched

    from torch.dict.base import TensorDictBase
    batched_inputs = []
    for in_dim, arg in zip(flat_in_dims, flat_args):
        if in_dim is None:
            if isinstance(arg, TensorDictBase):
                # this may be a perf bottleneck and could benefit from caching
                # arg = cache(arg.clone)(False)
                arg = arg.clone(False)

            batched_input = arg
        else:
            if isinstance(arg, TensorDictBase):
                batched_input = arg._add_batch_dim(
                    in_dim=in_dim, vmap_level=vmap_level
                )
            else:
                batched_input = _add_batch_dim(arg, in_dim, vmap_level)
        batched_inputs.append(batched_input)
    with _exclude_td_from_pytree():
        return tree_unflatten(batched_inputs, args_spec)


def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim):
@@ -160,36 +202,58 @@ def _unwrap_batched(
    batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
    out_dims: out_dims_t,
    vmap_level: int, batch_size: int, func: Callable) -> Tuple:
    flat_batched_outputs, output_spec = tree_flatten(batched_outputs)
    from torch.dict.base import TensorDictBase
    with _exclude_td_from_pytree():
        flat_batched_outputs, output_spec = tree_flatten(batched_outputs)

    for out in flat_batched_outputs:
        # Change here:
        if isinstance(out, (TensorDictBase, torch.Tensor)):
            continue
        raise ValueError(
            f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
            f"Tensors, got type {type(out)} as a return."
        )

    def incompatible_error():
        raise ValueError(
            f'vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): '
            f'out_dims is not compatible with the structure of `outputs`. '
            f'out_dims has structure {tree_flatten(out_dims)[1]} but outputs '
            f'has structure {output_spec}.')
            f"vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): "
            f"out_dims is not compatible with the structure of `outputs`. "
            f"out_dims has structure {tree_flatten(out_dims)[1]} but outputs "
            f"has structure {output_spec}."
        )

    if isinstance(batched_outputs, torch.Tensor):
    # Here:
    if isinstance(batched_outputs, (TensorDictBase, torch.Tensor)):
        # Some weird edge case requires us to spell out the following
        # see test_out_dims_edge_case
        if isinstance(out_dims, int):
            flat_out_dims = [out_dims]
        elif isinstance(out_dims, tuple) and len(out_dims) == 1:
            flat_out_dims = out_dims
        elif out_dims is None:
            flat_out_dims = [out_dims]
            out_dims = out_dims[0]
        else:
            incompatible_error()
    else:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec)
        if flat_out_dims is None:
            incompatible_error()

    flat_outputs = [
        _maybe_remove_batch_dim(_get_name(func), batched_output, vmap_level, batch_size, out_dim)
        for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims)
    ]
    return tree_unflatten(flat_outputs, output_spec)
    flat_outputs = []
    for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims):
        if not isinstance(batched_output, TensorDictBase):
            out = _remove_batch_dim(
                batched_output,
                vmap_level,
                batch_size,
                out_dim
            )
        else:
            out = batched_output._remove_batch_dim(
                vmap_level=vmap_level, batch_size=batch_size, out_dim=out_dim
            )
        flat_outputs.append(out)
    with _exclude_td_from_pytree():
        return tree_unflatten(flat_outputs, output_spec)


def _check_int_or_none(x, func, out_dims):
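A rough sketch of the behaviour the vmap changes above are aiming for: a TensorDict input is not flattened by the pytree machinery but handled through its own _add_batch_dim/_remove_batch_dim hooks. The example assumes the TensorDict constructor exposed elsewhere in this branch; it is illustrative, not taken from the PR's tests.

import torch
from torch.dict import TensorDict  # assumed import path on this branch

td = TensorDict({"a": torch.randn(4, 3), "b": torch.randn(4, 5)}, batch_size=[4])

def norms(data):
    # inside vmap, each entry is seen without the vmapped batch dimension
    return data["a"].norm() + data["b"].norm()

# in_dims=0 maps over the tensordict's first batch dimension; the batching is
# applied by TensorDict._add_batch_dim rather than the generic pytree path
out = torch.vmap(norms, in_dims=0)(td)  # expected shape: [4]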
torch/dict/__init__.py (new file, 7 lines):

@@ -0,0 +1,7 @@
from .base import TensorDictBase
from .functional import dense_stack_tds, merge_tensordicts, pad, pad_sequence
from .params import TensorDictParams
from .tensorclass import tensorclass
from .tensordict import TensorDict
from ._pytree import *
from ._lazy import LazyStackedTensorDict
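The new torch.dict namespace is evidently meant to be used like the external tensordict package. A minimal sketch, assuming the TensorDict constructor keeps the (source, batch_size) signature used throughout these files:

import torch
from torch.dict import TensorDict  # names exported by the __init__.py above

data = TensorDict(
    {"obs": torch.zeros(10, 3), "reward": torch.zeros(10, 1)},
    batch_size=[10],
)
sub = data[:5]  # indexing applies to every entry (and the batch size) at once
data["done"] = torch.zeros(10, 1, dtype=torch.bool)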
torch/dict/_keys.py (new file, 0 lines).

torch/dict/_lazy.py (new file, 2151 lines): diff suppressed because it is too large.

torch/dict/_memmap.py (new file, 704 lines):
@@ -0,0 +1,704 @@
from __future__ import annotations

import mmap
import os

import sys
import tempfile
from multiprocessing import util
from multiprocessing.context import reduction
from pathlib import Path
from typing import Any, overload

import numpy as np

import torch

from torch import distributed as dist

from torch.multiprocessing.reductions import ForkingPickler

from .utils import implement_for

try:
    if dist.is_available():
        from torch.distributed._tensor.api import DTensor
    else:
        raise ImportError
except ImportError:

    class DTensor(torch.Tensor):  # noqa: D101
        ...


class MemoryMappedTensor(torch.Tensor):
    """A Memory-mapped Tensor.

    Supports filenames or file handlers.

    The main advantage of MemoryMappedTensor resides in its serialization methods,
    which ensure that the tensor is passed through queues or RPC remote calls without
    any copy.

    .. note::
        When used within RPC settings, the filepath should be accessible to both nodes.
        If it isn't, the behaviour of passing a MemoryMappedTensor from one worker
        to another is undefined.

    MemoryMappedTensor supports multiple construction methods.

    Examples:
        >>> # from an existing tensor
        >>> tensor = torch.randn(3)
        >>> with tempfile.NamedTemporaryFile() as file:
        ...     memmap_tensor = MemoryMappedTensor.from_tensor(tensor, filename=file.name)
        ...     assert memmap_tensor.filename is not None
        >>> # if no filename is passed, a handler is used
        >>> tensor = torch.randn(3)
        >>> memmap_tensor = MemoryMappedTensor.from_tensor(tensor, filename=file.name)
        >>> assert memmap_tensor.filename is None
        >>> # one can create an empty tensor too
        >>> with tempfile.NamedTemporaryFile() as file:
        ...     memmap_tensor_empty = MemoryMappedTensor.empty_like(tensor, filename=file.name)
        >>> with tempfile.NamedTemporaryFile() as file:
        ...     memmap_tensor_zero = MemoryMappedTensor.zeros_like(tensor, filename=file.name)
        >>> with tempfile.NamedTemporaryFile() as file:
        ...     memmap_tensor = MemoryMappedTensor.ones_like(tensor, filename=file.name)
    """

    _filename: str | Path
    _handler: _FileHandler
    _clear: bool
    index: Any
    parent_shape: torch.Size

    def __new__(
        cls,
        tensor_or_file,
        *,
        dtype=None,
        shape=None,
        index=None,
        device=None,
        handler=None,
    ):
        if device is not None and torch.device(device).type != "cpu":
            raise ValueError(f"{cls} device must be cpu!")
        if isinstance(tensor_or_file, str):
            return cls.from_filename(
                tensor_or_file,
                dtype,
                shape,
                index,
            )
        elif handler is not None:
            return cls.from_handler(
                handler,
                dtype,
                shape,
                index,
            )
        return super().__new__(cls, tensor_or_file)

    def __init__(
        self, tensor_or_file, handler=None, dtype=None, shape=None, device=None
    ):
        ...

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def from_tensor(
        cls,
        input,
        *,
        filename=None,
        existsok=False,
        copy_existing=False,
        copy_data=True,
    ):
        """Creates a MemoryMappedTensor with the same content as another tensor.

        If the tensor is already a MemoryMappedTensor the original tensor is
        returned if the `filename` argument is `None` or if the two paths match.
        In all other cases, a new :class:`MemoryMappedTensor` is produced.

        Args:
            input (torch.Tensor): the tensor which content must be copied onto
                the MemoryMappedTensor.
            filename (path to a file): the path to the file where the tensor
                should be stored. If none is provided, a file handler is used
                instead.
            existsok (bool, optional): if ``True``, the file will overwrite
                an existing file. Defaults to ``False``.
            copy_existing (bool, optional): if ``True`` and the provided input
                is a MemoryMappedTensor with an associated filename, copying
                the content to the new location is permitted. Otherwise an
                exception is thrown. This behaviour exists to prevent
                inadvertently duplicating data on disk.
            copy_data (bool, optional): if ``True``, the content of the tensor
                will be copied on the storage. Defaults to ``True``.

        """
        if isinstance(input, MemoryMappedTensor):
            if (filename is None and input._filename is None) or (
                input._filename is not None
                and filename is not None
                and Path(filename).absolute() == Path(input.filename).absolute()
            ):
                # either location was not specified, or memmap is already in the
                # correct location, so just return the MemmapTensor unmodified
                return input
            elif not copy_existing and (
                input._filename is not None
                and filename is not None
                and Path(filename).absolute() != Path(input.filename).absolute()
            ):
                raise RuntimeError(
                    f"A filename was provided but the tensor already has a file associated "
                    f"({input.filename}). "
                    f"To copy the tensor onto the new location, pass copy_existing=True."
                )
        elif isinstance(input, np.ndarray):
            raise TypeError(
                "Convert input to torch.Tensor before calling MemoryMappedTensor.from_tensor."
            )
        if input.requires_grad:
            raise RuntimeError(
                "MemoryMappedTensor.from_tensor is incompatible with tensor.requires_grad."
            )
        shape = input.shape
        if filename is None:
            if input.dtype.is_floating_point:
                size = torch.finfo(input.dtype).bits // 8 * shape.numel()
            elif input.dtype.is_complex:
                raise ValueError(
                    "Complex-valued tensors are not supported by MemoryMappedTensor."
                )
            elif input.dtype == torch.bool:
                size = shape.numel()
            else:
                # assume integer
                size = torch.iinfo(input.dtype).bits // 8 * shape.numel()
            handler = _FileHandler(size)
            out = torch.frombuffer(memoryview(handler.buffer), dtype=input.dtype)
            out = out.view(shape)
            out = cls(out)
        else:
            handler = None
            if not existsok and os.path.exists(str(filename)):
                raise RuntimeError(f"The file {filename} already exists.")
            out = cls(
                torch.from_file(
                    str(filename), shared=True, dtype=input.dtype, size=shape.numel()
                ).view(input.shape)
            )
        out._handler = handler
        out._filename = filename
        out.index = None
        out.parent_shape = input.shape
        if copy_data:
            if isinstance(input, DTensor):
                input = input.full_tensor()
            out.copy_(input)
        return out

    @property
    def filename(self):
        """The filename of the tensor, if it has one.

        Raises an exception otherwise.
        """
        filename = self._filename
        if filename is None:
            raise RuntimeError("The MemoryMappedTensor has no file associated.")
        return filename

    @classmethod
    def empty_like(cls, input, *, filename=None):
        # noqa: D417
        """Creates a tensor with no content but the same shape and dtype as the input tensor.

        Args:
            input (torch.Tensor): the tensor to use as an example.

        Keyword Args:
            filename (path or equivalent): the path to the file, if any. If none
                is provided, a handler is used.
        """
        return cls.from_tensor(
            torch.zeros((), dtype=input.dtype, device=input.device).expand_as(input),
            filename=filename,
            copy_data=False,
        )

    @classmethod
    def full_like(cls, input, fill_value, *, filename=None):
        # noqa: D417
        """Creates a tensor with a single content indicated by the `fill_value` argument, but the same shape and dtype as the input tensor.

        Args:
            input (torch.Tensor): the tensor to use as an example.
            fill_value (float or equivalent): content of the tensor.

        Keyword Args:
            filename (path or equivalent): the path to the file, if any. If none
                is provided, a handler is used.
        """
        return cls.from_tensor(
            torch.zeros((), dtype=input.dtype, device=input.device).expand_as(input),
            filename=filename,
            copy_data=False,
        ).fill_(fill_value)

    @classmethod
    def zeros_like(cls, input, *, filename=None):
        # noqa: D417
        """Creates a tensor with a 0-filled content, but the same shape and dtype as the input tensor.

        Args:
            input (torch.Tensor): the tensor to use as an example.

        Keyword Args:
            filename (path or equivalent): the path to the file, if any. If none
                is provided, a handler is used.
        """
        return cls.from_tensor(
            torch.zeros((), dtype=input.dtype, device=input.device).expand_as(input),
            filename=filename,
            copy_data=False,
        ).fill_(0.0)

    @classmethod
    def ones_like(cls, input, *, filename=None):
        # noqa: D417
        """Creates a tensor with a 1-filled content, but the same shape and dtype as the input tensor.

        Args:
            input (torch.Tensor): the tensor to use as an example.

        Keyword Args:
            filename (path or equivalent): the path to the file, if any. If none
                is provided, a handler is used.
        """
        return cls.from_tensor(
            torch.ones((), dtype=input.dtype, device=input.device).expand_as(input),
            filename=filename,
            copy_data=False,
        ).fill_(1.0)

    @classmethod
    @overload
    def ones(cls, *size, dtype=None, device=None, filename=None):
        ...

    @classmethod
    @overload
    def ones(cls, shape, *, dtype=None, device=None, filename=None):
        ...

    @classmethod
    def ones(cls, *args, **kwargs):
        # noqa: D417
        """Creates a tensor with a 1-filled content, specific shape, dtype and filename.

        Args:
            shape (integers or torch.Size): the shape of the tensor.

        Keyword Args:
            dtype (torch.dtype): the dtype of the tensor.
            device (torch.device): the device of the tensor. Only `None` and `"cpu"`
                are accepted, any other device will raise an exception.
            filename (path or equivalent): the path to the file, if any. If none
                is provided, a handler is used.
        """
        shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
        if device is not None:
            device = torch.device(device)
            if device.type != "cpu":
                raise RuntimeError("Only CPU tensors are supported.")
        result = torch.ones((), dtype=dtype, device=device)
        if shape:
            if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
                shape = torch.Size(shape[0])
            else:
                shape = torch.Size(shape)
            result = result.expand(shape)
        return cls.from_tensor(
            result,
            filename=filename,
        )

    @classmethod
    @overload
    def zeros(cls, *size, dtype=None, device=None, filename=None):
        ...

    @classmethod
    @overload
    def zeros(cls, shape, *, dtype=None, device=None, filename=None):
        ...

    @classmethod
    def zeros(cls, *args, **kwargs):
        # noqa: D417
        """Creates a tensor with a 0-filled content, specific shape, dtype and filename.

        Args:
            shape (integers or torch.Size): the shape of the tensor.

        Keyword Args:
            dtype (torch.dtype): the dtype of the tensor.
            device (torch.device): the device of the tensor. Only `None` and `"cpu"`
                are accepted, any other device will raise an exception.
            filename (path or equivalent): the path to the file, if any. If none
                is provided, a handler is used.
        """
        shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
        if device is not None:
            device = torch.device(device)
            if device.type != "cpu":
                raise RuntimeError("Only CPU tensors are supported.")
        result = torch.zeros((), dtype=dtype, device=device)
        if shape:
            if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
                shape = torch.Size(shape[0])
            else:
                shape = torch.Size(shape)
            result = result.expand(shape)
        result = cls.from_tensor(
            result,
            filename=filename,
        )
        return result

    @classmethod
    @overload
    def empty(cls, *size, dtype=None, device=None, filename=None):
        ...

    @classmethod
    @overload
    def empty(cls, shape, *, dtype=None, device=None, filename=None):
        ...

    @classmethod
    def empty(cls, *args, **kwargs):
        # noqa: D417
        """Creates a tensor with empty content, specific shape, dtype and filename.

        Args:
            shape (integers or torch.Size): the shape of the tensor.

        Keyword Args:
            dtype (torch.dtype): the dtype of the tensor.
            device (torch.device): the device of the tensor. Only `None` and `"cpu"`
                are accepted, any other device will raise an exception.
            filename (path or equivalent): the path to the file, if any. If none
                is provided, a handler is used.
        """
        shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
        if device is not None:
            device = torch.device(device)
            if device.type != "cpu":
                raise RuntimeError("Only CPU tensors are supported.")
        result = torch.zeros((), dtype=dtype, device=device)
        if shape:
            if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
                shape = torch.Size(shape[0])
            else:
                shape = torch.Size(shape)
            result = result.expand(shape)
        result = cls.from_tensor(result, filename=filename)
        return result

    @classmethod
    @overload
    def full(cls, *size, fill_value, dtype=None, device=None, filename=None):
        ...

    @classmethod
    @overload
    def full(cls, shape, *, fill_value, dtype=None, device=None, filename=None):
        ...

    @classmethod
    def full(cls, *args, **kwargs):
        # noqa: D417
        """Creates a tensor with a single content specified by `fill_value`, specific shape, dtype and filename.

        Args:
            shape (integers or torch.Size): the shape of the tensor.

        Keyword Args:
            fill_value (float or equivalent): content of the tensor.
            dtype (torch.dtype): the dtype of the tensor.
            device (torch.device): the device of the tensor. Only `None` and `"cpu"`
                are accepted, any other device will raise an exception.
            filename (path or equivalent): the path to the file, if any. If none
                is provided, a handler is used.
        """
        shape, device, dtype, fill_value, filename = _proc_args_const(*args, **kwargs)
        if device is not None:
            device = torch.device(device)
            if device.type != "cpu":
                raise RuntimeError("Only CPU tensors are supported.")
        result = torch.zeros((), dtype=dtype, device=device).fill_(fill_value)
        if shape:
            if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
                shape = torch.Size(shape[0])
            else:
                shape = torch.Size(shape)
            result = result.expand(shape)
        return cls.from_tensor(result, filename=filename)

    @classmethod
    def from_filename(cls, filename, dtype, shape, index=None):
        # noqa: D417
        """Loads a MemoryMappedTensor from a given filename.

        Args:
            filename (path or equivalent): the path to the file.
            dtype (torch.dtype): the dtype of the tensor.
            shape (integers or torch.Size): the shape of the tensor.
            index (torch-compatible index type): an index to use to build the
                tensor.

        """
        shape = torch.Size(shape)
        tensor = torch.from_file(
            str(filename), shared=True, dtype=dtype, size=shape.numel()
        ).view(shape)
        if index is not None:
            tensor = tensor[index]
        out = cls(tensor)
        out._filename = filename
        out._handler = None
        out.index = index
        out.parent_shape = shape
        return out

    @classmethod
    def from_handler(cls, handler, dtype, shape, index):
        # noqa: D417
        """Loads a MemoryMappedTensor from a given handler.

        Args:
            handler (compatible file handler): the handler for the tensor.
            dtype (torch.dtype): the dtype of the tensor.
            shape (integers or torch.Size): the shape of the tensor.
            index (torch-compatible index type): an index to use to build the
                tensor.

        """
        shape = torch.Size(shape)
        out = torch.frombuffer(memoryview(handler.buffer), dtype=dtype)
        out = torch.reshape(out, shape)
        if index is not None:
            out = out[index]
        out = cls(out)
        out._filename = None
        out._handler = handler
        out.index = index
        out.parent_shape = shape
        return out

    @property
    def _tensor(self):
        # for bc-compatibility with MemmapTensor, to be deprecated in v0.4
        return self

    def __setstate__(self, state):
        if "filename" in state:
            self.__dict__ = type(self).from_filename(**state).__dict__
        else:
            self.__dict__ = type(self).from_handler(**state).__dict__

    def __getstate__(self):
        if getattr(self, "_handler", None) is not None:
            return {
                "handler": self._handler,
                "dtype": self.dtype,
                "shape": self.parent_shape,
                "index": self.index,
            }
        elif getattr(self, "_filename", None) is not None:
            return {
                "filename": self._filename,
                "dtype": self.dtype,
                "shape": self.parent_shape,
                "index": self.index,
            }
        else:
            raise RuntimeError("Could not find handler or filename.")

    def __reduce_ex__(self, protocol):
        return self.__reduce__()

    def __reduce__(self):
        if getattr(self, "_handler", None) is not None:
            return type(self).from_handler, (
                self._handler,
                self.dtype,
                self.parent_shape,
                self.index,
            )
        elif getattr(self, "_filename", None) is not None:
            return type(self).from_filename, (
                self._filename,
                self.dtype,
                self.parent_shape,
                self.index,
            )
        else:
            raise RuntimeError("Could not find handler or filename.")

    @implement_for("torch", "2.0", None)
    def __getitem__(self, item):
        try:
            out = super().__getitem__(item)
        except ValueError as err:
            if "is unbound" in str(err):
                raise ValueError(
                    "Using first class dimension indices with MemoryMappedTensor "
                    "isn't supported at the moment."
                ) from err
            raise
        if out.untyped_storage().data_ptr() == self.untyped_storage().data_ptr():
            out = MemoryMappedTensor(out)
            out._handler = self._handler
            out._filename = self._filename
            out.index = item
            out.parent_shape = self.parent_shape
        return out

    @implement_for("torch", None, "2.0")
    def __getitem__(self, item):  # noqa: F811
        try:
            out = super().__getitem__(item)
        except ValueError as err:
            if "is unbound" in str(err):
                raise ValueError(
                    "Using first class dimension indices with MemoryMappedTensor "
                    "isn't supported at the moment."
                ) from err
            raise
        if out.storage().data_ptr() == self.storage().data_ptr():
            out = MemoryMappedTensor(out)
            out._handler = self._handler
            out._filename = self._filename
            out.index = item
            out.parent_shape = self.parent_shape
        return out


#####################
# File handler
# borrowed from mp.heap

if sys.platform == "win32":
    import _winapi

    class _FileHandler:
        _rand = tempfile._RandomNameSequence()

        def __init__(self, size):
            self.size = size
            for _ in range(100):
                name = "pym-%d-%s" % (os.getpid(), next(self._rand))
                buf = mmap.mmap(-1, size, tagname=name)
                if _winapi.GetLastError() == 0:
                    break
                # We have reopened a preexisting mmap.
                buf.close()
            else:
                raise FileExistsError("Cannot find name for new mmap")
            self.name = name
            self.buffer = buf
            self._state = (self.size, self.name)

        def __getstate__(self):
            from multiprocessing.context import assert_spawning

            assert_spawning(self)
            return self._state

        def __setstate__(self, state):
            self.size, self.name = self._state = state
            # Reopen existing mmap
            self.buffer = mmap.mmap(-1, self.size, tagname=self.name)
            # XXX Temporarily preventing buildbot failures while determining
            # XXX the correct long-term fix. See issue 23060
            # assert _winapi.GetLastError() == _winapi.ERROR_ALREADY_EXISTS

else:

    class _FileHandler:
        if sys.platform == "linux":
            _dir_candidates = ["/dev/shm"]
        else:
            _dir_candidates = []

        def __init__(self, size, fd=-1):
            self.size = size
            self.fd = fd
            if fd == -1:
                self.fd, name = tempfile.mkstemp(
                    prefix="pym-%d-" % os.getpid(), dir=self._choose_dir(size)
                )
                os.unlink(name)
                util.Finalize(self, os.close, (self.fd,))
                os.ftruncate(self.fd, size)
            self.buffer = mmap.mmap(self.fd, self.size)

        def _choose_dir(self, size):
            # Choose a non-storage backed directory if possible,
            # to improve performance
            for d in self._dir_candidates:
                st = os.statvfs(d)
                if st.f_bavail * st.f_frsize >= size:  # enough free space?
                    return d
            return util.get_temp_dir()

    def _reduce_handler(handler):
        if handler.fd == -1:
            raise ValueError(
                "Handler is unpicklable because "
                "forking was enabled when it was created"
            )
        return _rebuild_handler, (handler.size, reduction.DupFd(handler.fd))

    def _rebuild_handler(size, dupfd):
        detached = dupfd.detach()
        return _FileHandler(size, detached)

    reduction.register(_FileHandler, _reduce_handler)


def _reduce_memmap(memmap_tensor):
    return memmap_tensor.__reduce__()


ForkingPickler.register(MemoryMappedTensor, _reduce_memmap)


def _proc_args_const(*args, **kwargs):
    if len(args) > 0:
        # then the first (or the N first) args are the shape
        if len(args) == 1 and not isinstance(args[0], int):
            shape = torch.Size(args[0])
        else:
            shape = torch.Size(args)
    else:
        # we should have a "shape" keyword arg
        shape = kwargs.pop("shape", None)
        if shape is None:
            raise TypeError("Could not find the shape argument in the arguments.")
        shape = torch.Size(shape)
    return (
        shape,
        kwargs.pop("device", None),
        kwargs.pop("dtype", None),
        kwargs.pop("fill_value", None),
        kwargs.pop("filename", None),
    )
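Since the main advantage named in the docstring is copy-free serialization, a small sketch of the round trip that the __reduce__/ForkingPickler hooks above make possible may help. This is an assumption-laden illustration, not part of the file; existsok=True sidesteps the existing-file check when reusing a NamedTemporaryFile.

import tempfile
import torch
from torch.dict._memmap import MemoryMappedTensor  # module added above

with tempfile.NamedTemporaryFile() as file:
    # existsok=True because the temporary file already exists on disk
    mm = MemoryMappedTensor.from_tensor(
        torch.arange(6.0), filename=file.name, existsok=True
    )
    # __reduce__ returns (from_filename, (filename, dtype, shape, index)),
    # so "pickling" only ships metadata, never the buffer itself
    fn, args = mm.__reduce__()
    rebuilt = fn(*args)  # a new view over the same memory-mapped file
    assert torch.equal(rebuilt, mm)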
torch/dict/_pytree.py (new file, 89 lines):

@@ -0,0 +1,89 @@
from typing import Any, Dict, List, Tuple

from torch.dict import TensorDictParams
from torch.dict.tensordict import _SubTensorDict, TensorDict
from torch.utils._pytree import Context, register_pytree_node

PYTREE_REGISTERED_TDS = (
    TensorDict,
    TensorDictParams,
    _SubTensorDict,
)

__all__ = []


def _str_to_dict(str_spec: str) -> Tuple[List[str], str]:
    assert str_spec[1] == "("
    assert str_spec[-1] == ")"
    context_and_child_strings = str_spec[2:-1]

    child_strings = []
    context_strings = []
    nested_parentheses = 0
    start_index = 0
    for i, char in enumerate(context_and_child_strings):
        if char == ":":
            if nested_parentheses == 0:
                context_strings.append(context_and_child_strings[start_index:i])
                start_index = i + 1
        elif char == "(":
            nested_parentheses += 1
        elif char == ")":
            nested_parentheses -= 1

        if nested_parentheses == 0 and char == ",":
            child_strings.append(context_and_child_strings[start_index:i])
            start_index = i + 1

    child_strings.append(context_and_child_strings[start_index:])
    return context_strings, ",".join(child_strings)


def _str_to_tensordictdict(str_spec: str) -> Tuple[List[str], str]:
    context_and_child_strings = str_spec[2:-1]

    child_strings = []
    context_strings = []
    nested_parentheses = 0
    start_index = 0
    for i, char in enumerate(context_and_child_strings):
        if char == ":":
            if nested_parentheses == 0:
                context_strings.append(context_and_child_strings[start_index:i])
                start_index = i + 1
        elif char == "(":
            nested_parentheses += 1
        elif char == ")":
            nested_parentheses -= 1

        if nested_parentheses == 0 and char == ",":
            child_strings.append(context_and_child_strings[start_index:i])
            start_index = i + 1

    child_strings.append(context_and_child_strings[start_index:])
    return context_strings, ",".join(child_strings)


def _tensordict_flatten(d: TensorDict) -> Tuple[List[Any], Context]:
    return list(d.values()), {
        "keys": list(d.keys()),
        "batch_size": d.batch_size,
        "names": d.names,
    }


def _tensordictdict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
    return TensorDict(
        dict(zip(context["keys"], values)),
        context["batch_size"],
        names=context["names"],
    )


for cls in PYTREE_REGISTERED_TDS:
    register_pytree_node(
        cls,
        _tensordict_flatten,
        _tensordictdict_unflatten,
    )
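A minimal sketch of what this registration buys: TensorDict instances become transparent to torch.utils._pytree, so flatten/unflatten round-trips through the leaf tensors plus a context holding keys, batch_size and names. The import path assumes the __init__.py shown earlier; the example is illustrative only.

import torch
from torch.dict import TensorDict
from torch.utils._pytree import tree_flatten, tree_unflatten

td = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, batch_size=[3])
leaves, spec = tree_flatten(td)          # leaves are the entry tensors
td2 = tree_unflatten([x * 2 for x in leaves], spec)
assert isinstance(td2, TensorDict)       # rebuilt via _tensordictdict_unflatten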
torch/dict/_torch_func.py (new file, 419 lines):

@@ -0,0 +1,419 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import functools
from typing import Any, Callable, Sequence, TypeVar

import torch
from torch import Tensor
from ._lazy import LazyStackedTensorDict
from .base import NO_DEFAULT, TensorDictBase
from .tensordict import TensorDict
from .utils import _check_keys, _ErrorInteceptor, DeviceType


TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
LAZY_TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
T = TypeVar("T", bound="TensorDictBase")


def implements_for_td(torch_function: Callable) -> Callable[[Callable], Callable]:
    """Register a torch function override for TensorDict."""

    @functools.wraps(torch_function)
    def decorator(func: Callable) -> Callable:
        TD_HANDLED_FUNCTIONS[torch_function] = func
        return func

    return decorator


def implements_for_lazy_td(torch_function: Callable) -> Callable[[Callable], Callable]:
    """Register a torch function override for TensorDict."""

    @functools.wraps(torch_function)
    def decorator(func: Callable) -> Callable:
        LAZY_TD_HANDLED_FUNCTIONS[torch_function] = func
        return func

    return decorator


@implements_for_td(torch.unbind)
def _unbind(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]:
    return td.unbind(*args, **kwargs)


@implements_for_td(torch.gather)
def _gather(
    input: T,
    dim: int,
    index: Tensor,
    *,
    sparse_grad: bool = False,
    out: T | None = None,
) -> T:
    if sparse_grad:
        raise NotImplementedError(
            "sparse_grad=True not implemented for torch.gather(tensordict, ...)"
        )
    # the index must have as many dims as the tensordict
    if not len(index):
        raise RuntimeError("Cannot use torch.gather with an empty index")
    dim_orig = dim
    if dim < 0:
        dim = input.batch_dims + dim
    if dim > input.batch_dims - 1 or dim < 0:
        raise RuntimeError(
            f"Cannot gather tensordict with shape {input.shape} along dim {dim_orig}."
        )

    def _gather_tensor(tensor, dest=None):
        index_expand = index
        while index_expand.ndim < tensor.ndim:
            index_expand = index_expand.unsqueeze(-1)
        target_shape = list(tensor.shape)
        target_shape[dim] = index_expand.shape[dim]
        index_expand = index_expand.expand(target_shape)
        out = torch.gather(tensor, dim, index_expand, out=dest)
        return out

    if out is None:
        names = input.names if input._has_names() else None

        return TensorDict(
            {key: _gather_tensor(value) for key, value in input.items()},
            batch_size=index.shape,
            names=names,
        )
    TensorDict(
        {key: _gather_tensor(value, out[key]) for key, value in input.items()},
        batch_size=index.shape,
    )
    return out


@implements_for_td(torch.full_like)
def _full_like(td: T, fill_value: float, **kwargs: Any) -> T:
    td_clone = td.clone()
    for key in td_clone.keys():
        td_clone.fill_(key, fill_value)
    if "dtype" in kwargs:
        raise ValueError("Cannot pass dtype to full_like with TensorDict")
    if "device" in kwargs:
        td_clone = td_clone.to(kwargs.pop("device"))
    if len(kwargs):
        raise RuntimeError(
            f"keyword arguments {list(kwargs.keys())} are not "
            f"supported with full_like with TensorDict"
        )
    return td_clone


@implements_for_td(torch.zeros_like)
def _zeros_like(td: T, **kwargs: Any) -> T:
    td_clone = td._fast_apply(torch.zeros_like)
    if "dtype" in kwargs:
        raise ValueError("Cannot pass dtype to full_like with TensorDict")
    if "device" in kwargs:
        td_clone = td_clone.to(kwargs.pop("device"))
    if len(kwargs):
        raise RuntimeError(
            f"keyword arguments {list(kwargs.keys())} are not "
            f"supported with full_like with TensorDict"
        )
    return td_clone


@implements_for_td(torch.ones_like)
def _ones_like(td: T, **kwargs: Any) -> T:
    td_clone = td._fast_apply(lambda x: torch.ones_like(x))
    if "device" in kwargs:
        td_clone = td_clone.to(kwargs.pop("device"))
    if len(kwargs):
        raise RuntimeError(
            f"keyword arguments {list(kwargs.keys())} are not "
            f"supported with full_like with TensorDict"
        )
    return td_clone


@implements_for_td(torch.empty_like)
def _empty_like(td: T, *args, **kwargs) -> T:
    try:
        tdclone = td.clone()
    except Exception as err:
        raise RuntimeError(
            "The tensordict passed to torch.empty_like cannot be "
            "cloned, preventing empty_like to be called. "
            "Consider calling tensordict.to_tensordict() first."
        ) from err
    return tdclone._fast_apply(
        lambda x: torch.empty_like(x, *args, **kwargs), inplace=True
    )


@implements_for_td(torch.clone)
def _clone(td: T, *args: Any, **kwargs: Any) -> T:
    return td.clone(*args, **kwargs)


@implements_for_td(torch.squeeze)
def _squeeze(td: T, *args: Any, **kwargs: Any) -> T:
    return td.squeeze(*args, **kwargs)


@implements_for_td(torch.unsqueeze)
def _unsqueeze(td: T, *args: Any, **kwargs: Any) -> T:
    return td.unsqueeze(*args, **kwargs)


@implements_for_td(torch.masked_select)
def _masked_select(td: T, *args: Any, **kwargs: Any) -> T:
    return td.masked_select(*args, **kwargs)


@implements_for_td(torch.permute)
def _permute(td: T, dims: Sequence[int]) -> T:
    return td.permute(*dims)


@implements_for_td(torch.cat)
def _cat(
    list_of_tensordicts: Sequence[T],
    dim: int = 0,
    device: DeviceType | None = None,
    out: T | None = None,
) -> T:
    if not list_of_tensordicts:
        raise RuntimeError("list_of_tensordicts cannot be empty")

    batch_size = list(list_of_tensordicts[0].batch_size)
    if dim < 0:
        dim = len(batch_size) + dim
    if dim >= len(batch_size):
        raise RuntimeError(
            f"dim must be in the range 0 <= dim < len(batch_size), got dim"
            f"={dim} and batch_size={batch_size}"
        )
    batch_size[dim] = sum([td.batch_size[dim] for td in list_of_tensordicts])
    batch_size = torch.Size(batch_size)

    # check that all tensordict match
    keys = _check_keys(list_of_tensordicts, strict=True)
    if out is None:
        out = {}
        for key in keys:
            with _ErrorInteceptor(
                key, "Attempted to concatenate tensors on different devices at key"
            ):
                out[key] = torch.cat(
                    [td._get_str(key, NO_DEFAULT) for td in list_of_tensordicts], dim
                )
        if device is None:
            device = list_of_tensordicts[0].device
            for td in list_of_tensordicts[1:]:
                if device == td.device:
                    continue
                else:
                    device = None
                    break
        names = None
        if list_of_tensordicts[0]._has_names():
            names = list_of_tensordicts[0].names
        return TensorDict(
            out, device=device, batch_size=batch_size, _run_checks=False, names=names
        )
    else:
        if out.batch_size != batch_size:
            raise RuntimeError(
                "out.batch_size and cat batch size must match, "
                f"got out.batch_size={out.batch_size} and batch_size"
                f"={batch_size}"
            )

        for key in keys:
            with _ErrorInteceptor(
                key, "Attempted to concatenate tensors on different devices at key"
            ):
                if isinstance(out, TensorDict):
                    torch.cat(
                        [td.get(key) for td in list_of_tensordicts],
                        dim,
                        out=out.get(key),
                    )
                else:
                    out.set_(
                        key, torch.cat([td.get(key) for td in list_of_tensordicts], dim)
                    )
        return out


@implements_for_lazy_td(torch.cat)
def _lazy_cat(
    list_of_tensordicts: Sequence[LazyStackedTensorDict],
    dim: int = 0,
    out: LazyStackedTensorDict | None = None,
) -> LazyStackedTensorDict:
    if not list_of_tensordicts:
        raise RuntimeError("list_of_tensordicts cannot be empty")

    batch_size = list(list_of_tensordicts[0].batch_size)
    if dim < 0:
        dim = len(batch_size) + dim
    if dim >= len(batch_size):
        raise RuntimeError(
            f"dim must be in the range 0 <= dim < len(batch_size), got dim"
            f"={dim} and batch_size={batch_size}"
        )
    stack_dim = list_of_tensordicts[0].stack_dim
    if any((td.stack_dim != stack_dim) for td in list_of_tensordicts):
        raise RuntimeError("cat lazy stacked tds must have same stack dim")

    batch_size[dim] = sum(td.batch_size[dim] for td in list_of_tensordicts)
    batch_size = torch.Size(batch_size)

    new_dim = dim
    if dim > stack_dim:
        new_dim = dim - 1

    if out is None:
        out = []
        if dim == stack_dim:  # if dim is stack, just add all to the same list
            for lazy_td in list_of_tensordicts:
                out += lazy_td.tensordicts
        else:
            for i in range(len(list_of_tensordicts[0].tensordicts)):
                out.append(
                    torch.cat(
                        [lazy_td.tensordicts[i] for lazy_td in list_of_tensordicts],
                        new_dim,
                    )
                )
        return LazyStackedTensorDict(*out, stack_dim=stack_dim)
    else:
        if not isinstance(out, LazyStackedTensorDict):
            return _cat(list_of_tensordicts, dim=dim, out=out)

        if out.batch_size != batch_size:
            raise RuntimeError(
                "out.batch_size and cat batch size must match, "
                f"got out.batch_size={out.batch_size} and batch_size"
                f"={batch_size}"
            )
        if out.stack_dim != dim:
            index_base = (slice(None),) * out.stack_dim
            for i, sub_dest in enumerate(out.tensordicts):
                index = index_base + (i,)
                tds_to_cat = [_td[index] for _td in list_of_tensordicts]
                torch.cat(tds_to_cat, dim, out=sub_dest)
        else:
            init_idx = 0
            for td_in in list_of_tensordicts:
                sub_dest = out.tensordicts[init_idx : init_idx + td_in.shape[dim]]
                init_idx += init_idx + td_in.shape[dim]
                torch.stack(sub_dest, out.stack_dim).update(td_in, inplace=True)

        return out


@implements_for_td(torch.stack)
def _stack(
    list_of_tensordicts: Sequence[TensorDictBase],
    dim: int = 0,
    device: DeviceType | None = None,
    out: T | None = None,
    strict: bool = False,
    contiguous: bool = False,
) -> T:
    if not list_of_tensordicts:
        raise RuntimeError("list_of_tensordicts cannot be empty")
    batch_size = list_of_tensordicts[0].batch_size
    if dim < 0:
        dim = len(batch_size) + dim + 1

    for td in list_of_tensordicts[1:]:
        if td.batch_size != list_of_tensordicts[0].batch_size:
            raise RuntimeError(
                "stacking tensordicts requires them to have congruent batch sizes, "
                f"got td1.batch_size={td.batch_size} and td2.batch_size="
                f"{list_of_tensordicts[0].batch_size}"
            )

    # check that all tensordict match
    keys = _check_keys(list_of_tensordicts)
    result_batch_size = list(batch_size)
    result_batch_size.insert(dim, len(list_of_tensordicts))
    result_batch_size = torch.Size(result_batch_size)

    if out is None:
        device = list_of_tensordicts[0].device
        out = {}
        for key in keys:
            with _ErrorInteceptor(
                key, "Attempted to stack tensors on different devices at key"
            ):
                out[key] = torch.stack(
                    [_tensordict.get(key) for _tensordict in list_of_tensordicts],
                    dim,
                )
        return TensorDict(
            out,
            batch_size=result_batch_size,
            device=device,
            _run_checks=False,
        )
    else:
        if out.batch_size != result_batch_size:
            raise RuntimeError(
                "out.batch_size and stacked batch size must match, "
                f"got out.batch_size={out.batch_size} and resulting batch_size"
                f"={result_batch_size}"
            )

        out_keys = set(out.keys())
        if strict:
            in_keys = set(keys)
            if len(out_keys - in_keys) > 0:
                raise RuntimeError(
                    "The output tensordict has keys that are missing in the "
                    "tensordict that has to be written: {out_keys - in_keys}. "
                    "As per the call to `stack(..., strict=True)`, this "
                    "is not permitted."
                )
            elif len(in_keys - out_keys) > 0:
                raise RuntimeError(
                    "The resulting tensordict has keys that are missing in "
                    f"its destination: {in_keys - out_keys}. As per the call "
                    "to `stack(..., strict=True)`, this is not permitted."
                )

        out._stack_onto_(list_of_tensordicts, dim)

    return out


@implements_for_td(torch.split)
def _split(
    td: TensorDict, split_size_or_sections: int | list[int], dim: int = 0
) -> list[TensorDictBase]:
    return td.split(split_size_or_sections, dim)


@implements_for_td(torch.where)
def where(condition, input, other, *, out=None):
    """Return a ``TensorDict`` of elements selected from either input or other, depending on condition.

    Args:
        condition (BoolTensor): When ``True`` (nonzero), yield ``input``, otherwise yield ``other``.
        input (TensorDictBase or Scalar): value (if ``input`` is a scalar) or values selected at indices where condition is ``True``.
        other (TensorDictBase or Scalar): value (if ``other`` is a scalar) or values selected at indices where condition is ``False``.
        out (Tensor, optional): the output ``TensorDictBase`` instance.

    """
    return input.where(condition, other, out=out)
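These registrations are what top-level torch calls dispatch to for tensordicts, assuming TensorDictBase implements a __torch_function__ that consults the TD_HANDLED_FUNCTIONS table this file populates (that plumbing lives in base.py, which is suppressed above). A hedged sketch of the intended behaviour; shapes are illustrative:

import torch
from torch.dict import TensorDict

a = TensorDict({"x": torch.zeros(2, 3)}, batch_size=[2])
b = TensorDict({"x": torch.ones(2, 3)}, batch_size=[2])

stacked = torch.stack([a, b], dim=0)  # routed to _stack, batch_size [2, 2]
joined = torch.cat([a, b], dim=0)     # routed to _cat,   batch_size [4]
zeros = torch.zeros_like(joined)      # routed to _zeros_like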
torch/dict/base.py (new file, 4500 lines): diff suppressed because it is too large.

torch/dict/functional.py (new file, 338 lines):
@ -0,0 +1,338 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, TypeVar
|
||||
|
||||
import torch
|
||||
from ._lazy import LazyStackedTensorDict
|
||||
from .base import _is_tensor_collection, CompatibleType, TensorDictBase
|
||||
from .tensordict import TensorDict
|
||||
from .utils import _check_keys, _shape, DeviceType
|
||||
|
||||
T = TypeVar("T", bound="TensorDictBase")
|
||||
|
||||
|
||||
def pad(tensordict: T, pad_size: Sequence[int], value: float = 0.0) -> T:
|
||||
"""Pads all tensors in a tensordict along the batch dimensions with a constant value, returning a new tensordict.
|
||||
|
||||
Args:
|
||||
tensordict (TensorDict): The tensordict to pad
|
||||
pad_size (Sequence[int]): The padding size by which to pad some batch
|
||||
dimensions of the tensordict, starting from the first dimension and
|
||||
moving forward. [len(pad_size) / 2] dimensions of the batch size will
|
||||
be padded. For example to pad only the first dimension, pad has the form
|
||||
(padding_left, padding_right). To pad two dimensions,
|
||||
(padding_left, padding_right, padding_top, padding_bottom) and so on.
|
||||
pad_size must be even and less than or equal to twice the number of batch dimensions.
|
||||
value (float, optional): The fill value to pad by, default 0.0
|
||||
|
||||
Returns:
|
||||
A new TensorDict padded along the batch dimensions
|
||||
|
||||
Examples:
|
||||
>>> from torch.dict import TensorDict, pad
|
||||
>>> import torch
|
||||
>>> td = TensorDict({'a': torch.ones(3, 4, 1),
|
||||
... 'b': torch.ones(3, 4, 1, 1)}, batch_size=[3, 4])
|
||||
>>> dim0_left, dim0_right, dim1_left, dim1_right = [0, 1, 0, 2]
|
||||
>>> padded_td = pad(td, [dim0_left, dim0_right, dim1_left, dim1_right], value=0.0)
|
||||
>>> print(padded_td.batch_size)
|
||||
torch.Size([4, 6])
|
||||
>>> print(padded_td.get("a").shape)
|
||||
torch.Size([4, 6, 1])
|
||||
>>> print(padded_td.get("b").shape)
|
||||
torch.Size([4, 6, 1, 1])
|
||||
|
||||
"""
|
||||
if len(pad_size) > 2 * len(tensordict.batch_size):
|
||||
raise RuntimeError(
|
||||
"The length of pad_size must be <= 2 * the number of batch dimensions"
|
||||
)
|
||||
|
||||
if len(pad_size) % 2:
|
||||
raise RuntimeError("pad_size must have an even number of dimensions")
|
||||
|
||||
new_batch_size = list(tensordict.batch_size)
|
||||
for i in range(len(pad_size)):
|
||||
new_batch_size[i // 2] += pad_size[i]
|
||||
|
||||
reverse_pad = pad_size[::-1]
|
||||
for i in range(0, len(reverse_pad), 2):
|
||||
reverse_pad[i], reverse_pad[i + 1] = reverse_pad[i + 1], reverse_pad[i]
|
||||
|
||||
out = TensorDict(
|
||||
{}, torch.Size(new_batch_size), device=tensordict.device, _run_checks=False
|
||||
)
|
||||
for key, tensor in tensordict.items():
|
||||
cur_pad = reverse_pad
|
||||
if len(pad_size) < len(_shape(tensor)) * 2:
|
||||
cur_pad = [0] * (len(_shape(tensor)) * 2 - len(pad_size)) + reverse_pad
|
||||
|
||||
if _is_tensor_collection(tensor.__class__):
|
||||
padded = pad(tensor, pad_size, value)
|
||||
else:
|
||||
padded = torch.nn.functional.pad(tensor, cur_pad, value=value)
|
||||
out.set(key, padded)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def pad_sequence(
    list_of_tensordicts: Sequence[T],
    batch_first: bool = True,
    padding_value: float = 0.0,
    out: T | None = None,
    device: DeviceType | None = None,
    return_mask: bool | None = False,
) -> T:
    """Pads a list of tensordicts so that they can be stacked together in a contiguous format.

    Args:
        list_of_tensordicts (List[TensorDictBase]): the list of instances to pad and stack.
        batch_first (bool, optional): the ``batch_first`` counterpart of :func:`torch.nn.utils.rnn.pad_sequence`.
            Defaults to ``True``.
        padding_value (number, optional): the padding value. Defaults to ``0.0``.
        out (TensorDictBase, optional): if provided, the destination where the data will be
            written.
        device (device compatible type, optional): if provided, the device where the
            TensorDict output will be created.
        return_mask (bool, optional): if ``True``, a "mask" entry will be returned.
            It contains the mask of valid values in the stacked tensordict.

    Examples:
        >>> list_td = [
        ...     TensorDict({"a": torch.zeros((3,))}, []),
        ...     TensorDict({"a": torch.zeros((4,))}, []),
        ...     ]
        >>> padded_td = pad_sequence(list_td)
        >>> print(padded_td)
        TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)
    """
    if not list_of_tensordicts:
        raise RuntimeError("list_of_tensordicts cannot be empty")
    # check that all tensordicts match
    if return_mask:
        list_of_tensordicts = [
            td.clone(False).set("mask", torch.ones(td.shape, dtype=torch.bool))
            for td in list_of_tensordicts
        ]
    keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True)
    shape = max(len(td) for td in list_of_tensordicts)
    if shape == 0:
        shape = [
            len(list_of_tensordicts),
        ]
    elif batch_first:
        shape = [len(list_of_tensordicts), shape]
    else:
        shape = [shape, len(list_of_tensordicts)]
    if out is None:
        out = TensorDict(
            {}, batch_size=torch.Size(shape), device=device, _run_checks=False
        )
        for key in keys:
            try:
                out.set(
                    key,
                    torch.nn.utils.rnn.pad_sequence(
                        [td.get(key) for td in list_of_tensordicts],
                        batch_first=batch_first,
                        padding_value=padding_value,
                    ),
                )
            except Exception as err:
                raise RuntimeError(f"pad_sequence failed for key {key}") from err
        return out
    else:
        for key in keys:
            out.set_(
                key,
                torch.nn.utils.rnn.pad_sequence(
                    [td.get(key) for td in list_of_tensordicts],
                    batch_first=batch_first,
                    padding_value=padding_value,
                ),
            )
        return out

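A hedged usage sketch of the return_mask path above, assuming pad_sequence is exported from torch.dict
alongside TensorDict; the printed mask is what the code implies, with padded slots becoming False:

    import torch
    from torch.dict import TensorDict, pad_sequence

    list_td = [
        TensorDict({"a": torch.zeros(3, 8)}, [3]),
        TensorDict({"a": torch.zeros(4, 8)}, [4]),
    ]
    padded = pad_sequence(list_td, return_mask=True)
    # "a" is padded to shape (2, 4, 8); "mask" marks which positions hold real values.
    print(padded.get("mask"))
    # tensor([[ True,  True,  True, False],
    #         [ True,  True,  True,  True]])
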
def merge_tensordicts(*tensordicts: T) -> T:
    """Merges tensordicts together."""
    if len(tensordicts) < 2:
        raise RuntimeError(
            f"at least 2 tensordicts must be provided, got {len(tensordicts)}"
        )
    d = tensordicts[0].to_dict()
    batch_size = tensordicts[0].batch_size
    for td in tensordicts[1:]:
        d.update(td.to_dict())
        if td.batch_dims < len(batch_size):
            batch_size = td.batch_size
    return TensorDict(d, batch_size, device=td.device, _run_checks=False)

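A minimal sketch of the merge semantics implied by the code above (later tensordicts overwrite earlier
keys, and the result keeps the batch size with the fewest batch dimensions), assuming merge_tensordicts
is exported from torch.dict:

    import torch
    from torch.dict import TensorDict, merge_tensordicts

    td0 = TensorDict({"a": torch.zeros(3, 4), "b": torch.zeros(3, 4)}, [3, 4])
    td1 = TensorDict({"b": torch.ones(3, 4), "c": torch.ones(3, 4)}, [3])
    merged = merge_tensordicts(td0, td1)
    print(merged.batch_size)   # torch.Size([3]): td1 has fewer batch dims
    print(merged["b"].mean())  # tensor(1.): "b" was taken from td1
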
def dense_stack_tds(
    td_list: Sequence[TensorDictBase] | LazyStackedTensorDict,
    dim: int = None,
) -> T:
    """Densely stack a list of :class:`~tensordict.TensorDictBase` objects (or a :class:`~tensordict.LazyStackedTensorDict`) given that they have the same structure.

    This function is called with a list of :class:`~tensordict.TensorDictBase` (either passed directly or obtained from
    a :class:`~tensordict.LazyStackedTensorDict`).
    Instead of calling ``torch.stack(td_list)``, which would return a :class:`~tensordict.LazyStackedTensorDict`,
    this function expands the first element of the input list and stacks the input list onto that element.
    This works only when all the elements of the input list have the same structure.
    The :class:`~tensordict.TensorDictBase` returned will have the same type as the elements of the input list.

    This function is useful when some of the :class:`~tensordict.TensorDictBase` objects that need to be stacked
    are :class:`~tensordict.LazyStackedTensorDict` or have :class:`~tensordict.LazyStackedTensorDict`
    among entries (or nested entries).
    In those cases, calling ``torch.stack(td_list).to_tensordict()`` is infeasible.
    Thus, this function provides an alternative for densely stacking the list provided.

    Args:
        td_list (List of TensorDictBase or LazyStackedTensorDict): the tds to stack.
        dim (int, optional): the dimension along which to stack them.
            If td_list is a LazyStackedTensorDict, it will be retrieved automatically.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from tensordict import dense_stack_tds
        >>> from tensordict.tensordict import assert_allclose_td
        >>> td0 = TensorDict({"a": torch.zeros(3)},[])
        >>> td1 = TensorDict({"a": torch.zeros(4), "b": torch.zeros(2)},[])
        >>> td_lazy = torch.stack([td0, td1], dim=0)
        >>> td_container = TensorDict({"lazy": td_lazy}, [])
        >>> td_container_clone = td_container.clone()
        >>> td_stack = torch.stack([td_container, td_container_clone], dim=0)
        >>> td_stack
        LazyStackedTensorDict(
            fields={
                lazy: LazyStackedTensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    exclusive_fields={
                    },
                    batch_size=torch.Size([2, 2]),
                    device=None,
                    is_shared=False,
                    stack_dim=0)},
            exclusive_fields={
            },
            batch_size=torch.Size([2]),
            device=None,
            is_shared=False,
            stack_dim=0)
        >>> td_stack = dense_stack_tds(td_stack)  # Automatically use the LazyStackedTensorDict stack_dim
        TensorDict(
            fields={
                lazy: LazyStackedTensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    exclusive_fields={
                        1 ->
                            b: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([2, 2]),
                    device=None,
                    is_shared=False,
                    stack_dim=1)},
            batch_size=torch.Size([2]),
            device=None,
            is_shared=False)
        # Note that
        # (1) td_stack is now a TensorDict
        # (2) this has pushed the stack_dim of "lazy" (0 -> 1)
        # (3) this has revealed the exclusive keys.
        >>> assert_allclose_td(td_stack, dense_stack_tds([td_container, td_container_clone], dim=0))
        # This shows it is the same to pass a list or a LazyStackedTensorDict

    """
    if isinstance(td_list, LazyStackedTensorDict):
        dim = td_list.stack_dim
        td_list = td_list.tensordicts
    elif dim is None:
        raise ValueError(
            "If a list of tensordicts is provided, stack_dim must not be None"
        )
    shape = list(td_list[0].shape)
    shape.insert(dim, len(td_list))

    result = td_list[0].unsqueeze(dim)
    result = result.expand(shape)
    result = result.clone()
    return LazyStackedTensorDict.maybe_dense_stack(td_list, dim=dim, out=result)

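A short usage sketch for the list form, assuming dense_stack_tds is exported from torch.dict (the
docstring example above imports it from tensordict, which is where this code originates):

    import torch
    from torch.dict import TensorDict, dense_stack_tds

    td0 = TensorDict({"a": torch.zeros(3)}, [])
    td1 = TensorDict({"a": torch.ones(3)}, [])
    # With a plain list the stacking dim must be given; with a LazyStackedTensorDict
    # it is read from stack_dim automatically.
    dense = dense_stack_tds([td0, td1], dim=0)
    print(dense["a"].shape)  # torch.Size([2, 3])
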
def make_tensordict(
    input_dict: dict[str, CompatibleType] | None = None,
    batch_size: Sequence[int] | torch.Size | int | None = None,
    device: DeviceType | None = None,
    **kwargs: CompatibleType,  # source
) -> TensorDict:
    """Returns a TensorDict created from the keyword arguments or an input dictionary.

    If ``batch_size`` is not specified, returns the maximum batch size possible.

    This function works on nested dictionaries too, or can be used to determine the
    batch-size of a nested tensordict.

    Args:
        input_dict (dictionary, optional): a dictionary to use as a data source
            (nested keys compatible).
        **kwargs (TensorDict or torch.Tensor): keyword arguments as data source
            (incompatible with nested keys).
        batch_size (iterable of int, optional): a batch size for the tensordict.
        device (torch.device or compatible type, optional): a device for the TensorDict.

    Examples:
        >>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)}
        >>> print(make_tensordict(input_dict))
        TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)
        >>> # alternatively
        >>> td = make_tensordict(**input_dict)
        >>> # nested dict: the nested TensorDict can have a different batch-size
        >>> # as long as its leading dims match.
        >>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}
        >>> print(make_tensordict(input_dict))
        TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: TensorDict(
                    fields={
                        c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([3, 4]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)
        >>> # we can also use this to work out the batch size of a tensordict
        >>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, [])
        >>> print(make_tensordict(input_td))
        TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: TensorDict(
                    fields={
                        c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([3, 4]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)
    """
    if input_dict is not None:
        kwargs.update(input_dict)
    return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device)
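A brief sketch showing the explicit batch_size argument, assuming make_tensordict is exported from
torch.dict like the other helpers above; the batch size given must be a prefix of every leaf's shape:

    import torch
    from torch.dict import make_tensordict

    td = make_tensordict(a=torch.randn(3, 4), b=torch.randn(3, 4, 5), batch_size=[3, 4])
    print(td.batch_size)  # torch.Size([3, 4])
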
torch/dict/params.py: new file, 1174 lines (diff suppressed because it is too large)
torch/dict/tensorclass.py: new file, 1361 lines (diff suppressed because it is too large)
torch/dict/tensordict.py: new file, 2813 lines (diff suppressed because it is too large)
torch/dict/utils.py: new file, 1526 lines (diff suppressed because it is too large)
@@ -89,57 +89,79 @@ def _untie_named_tensors_map(
 @contextlib.contextmanager
 def _reparametrize_module(
     module: "torch.nn.Module",
-    parameters_and_buffers: Dict[str, Tensor],
+    parameters_and_buffers: Union[Dict[str, Tensor], "TensorDictBase"],
     *,
     tie_weights: bool = False,
     strict: bool = False,
 ) -> Iterator[None]:
-    if tie_weights:
-        untied_parameters_and_buffers = _untie_named_tensors_map(
-            module, parameters_and_buffers
-        )
-    else:
-        untied_parameters_and_buffers = parameters_and_buffers
-
-    accessor = NamedMemberAccessor(module)
-    if strict:
-        missing_keys, unexpected_keys = accessor.check_keys(
-            untied_parameters_and_buffers
-        )
-        error_msgs = []
-        if len(unexpected_keys) > 0:
-            error_msgs.append(
-                f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}."
-            )
-        if len(missing_keys) > 0:
-            error_msgs.append(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
-        if len(error_msgs) > 0:
-            raise RuntimeError(
-                "Error(s) in reparametrizing for {}:\n\t{}".format(
-                    module._get_name(), "\n\t".join(error_msgs)
-                )
-            )
-
-    orig_parameters_and_buffers: Dict[str, Tensor] = {}
-    try:
-        orig_parameters_and_buffers, _ = accessor.swap_tensors_dict(
-            untied_parameters_and_buffers, allow_missing=True
-        )
-        yield
-    finally:
-        new_parameters_and_buffers, _ = accessor.swap_tensors_dict(
-            orig_parameters_and_buffers, allow_missing=True
-        )
-        # Sometimes the module is not completely stateless and has some in-place modifications on
-        # the _parameters and _buffers dictionaries.
-        # Write the changed parameters and buffers back to the original dict.
-        parameters_and_buffers.update(
-            {
-                k: new_parameters_and_buffers[k]
-                for k in parameters_and_buffers
-                if k in new_parameters_and_buffers
-            }
-        )
+    from torch.dict import TensorDictBase
+
+    if isinstance(parameters_and_buffers, TensorDictBase):
+        if strict:
+            raise NotImplementedError
+        if not tie_weights:
+            raise NotImplementedError
+        orig_parameters_and_buffers = parameters_and_buffers.empty()
+        try:
+            orig_parameters_and_buffers = parameters_and_buffers.to_module(
+                module, return_swap=True
+            )
+            yield
+        finally:
+            # tensordict is locked by default in this case, so we unlock it as we can't tell if an inplace update
+            # can be done (most likely not)
+            orig_parameters_and_buffers.to_module(
+                module, return_swap=True, swap_dest=parameters_and_buffers
+            )
+    else:
+        if tie_weights:
+            untied_parameters_and_buffers = _untie_named_tensors_map(
+                module, parameters_and_buffers
+            )
+        else:
+            untied_parameters_and_buffers = parameters_and_buffers
+
+        accessor = NamedMemberAccessor(module)
+        if strict:
+            missing_keys, unexpected_keys = accessor.check_keys(
+                untied_parameters_and_buffers
+            )
+            error_msgs = []
+            if len(unexpected_keys) > 0:
+                error_msgs.append(
+                    f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}."
+                )
+            if len(missing_keys) > 0:
+                error_msgs.append(
+                    f"Missing key(s): {', '.join(map(repr, missing_keys))}."
+                )
+            if len(error_msgs) > 0:
+                raise RuntimeError(
+                    "Error(s) in reparametrizing for {}:\n\t{}".format(
+                        module._get_name(), "\n\t".join(error_msgs)
+                    )
+                )
+
+        orig_parameters_and_buffers: Dict[str, Tensor] = {}
+        try:
+            orig_parameters_and_buffers, _ = accessor.swap_tensors_dict(
+                untied_parameters_and_buffers, allow_missing=True
+            )
+            yield
+        finally:
+            new_parameters_and_buffers, _ = accessor.swap_tensors_dict(
+                orig_parameters_and_buffers, allow_missing=True
+            )
+            # Sometimes the module is not completely stateless and has some in-place modifications on
+            # the _parameters and _buffers dictionaries.
+            # Write the changed parameters and buffers back to the original dict.
+            parameters_and_buffers.update(
+                {
+                    k: new_parameters_and_buffers[k]
+                    for k in parameters_and_buffers
+                    if k in new_parameters_and_buffers
+                }
+            )
 
 
 def functional_call(
@@ -228,7 +250,7 @@ def functional_call(
 
 def _functional_call(
     module: "torch.nn.Module",
-    parameters_and_buffers: Dict[str, Tensor],
+    parameters_and_buffers: Union[Dict[str, Tensor], "TensorDictBase"],
     args: Union[Any, Tuple],
     kwargs: Optional[Dict[str, Any]] = None,
     *,
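Taken together, these two hunks let the stateless functional-call path accept a TensorDict of parameters
directly instead of a flat dict. A minimal sketch of how that might be exercised, assuming the torch.dict
TensorDict added in this diff and a flat parameter layout that mirrors the module (strict checking is not
implemented for this path, and weight tying must stay enabled, per the code above):

    import torch
    from torch import nn
    from torch.dict import TensorDict

    module = nn.Linear(3, 2)
    params = TensorDict(
        {"weight": torch.zeros(2, 3), "bias": torch.zeros(2)}, batch_size=[]
    )
    x = torch.randn(4, 3)
    # functional_call swaps the TensorDict's tensors into the module for the call,
    # then restores the original parameters afterwards.
    out = torch.func.functional_call(module, params, (x,))
    print(out.abs().sum())  # tensor(0.) with all-zero parameters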