Compare commits

...

24 Commits

SHA1 Message Date
ba96cfd2f7 amend 2024-01-11 17:33:10 +00:00
7af4f58380 amend 2024-01-10 18:00:11 +00:00
4352ebe8bc amend 2024-01-10 12:12:42 +00:00
d52fcc6d33 Merge branch 'main' into tensordict_integration 2024-01-10 11:12:49 +00:00
cd0e9c4c05 fix 2023-11-16 17:48:57 +00:00
8fc1309f9f edits 2023-11-16 17:08:29 +00:00
11a52e8d6d Merge remote-tracking branch 'origin/main' into tensordict_integration 2023-11-10 10:57:10 -05:00
52061a05a4 functional decorator 2023-11-08 21:04:47 -05:00
afe4a40805 functional decorator 2023-11-08 20:56:32 -05:00
f5c809e33d functional efficiency 2023-11-06 17:27:46 -05:00
1d9582c627 fixes 2023-11-06 16:49:46 -05:00
2331b048af no more pickling 2023-11-06 16:09:47 -05:00
fb2103bee1 faster from_module 2023-11-06 11:52:54 -05:00
2a69d65d52 lint 2023-11-05 18:22:58 -05:00
f36bd08109 native shared tensors 2023-11-05 18:17:27 -05:00
d6160943b1 partial fix 2023-11-02 21:18:50 +00:00
d147161883 partial fix 2023-11-02 18:10:19 +00:00
6d3b90b64b more tests 2023-11-01 17:23:38 +00:00
d040d35294 fix tensorclass tests 2023-11-01 13:33:55 +00:00
cf6704548a fixes 2023-11-01 12:18:49 +00:00
b8ab16bad6 tensorclass 2023-11-01 11:46:15 +00:00
570605e37d amend 2023-11-01 10:25:06 +00:00
d307e5e0be Merge remote-tracking branch 'origin/main' into tensordict_integration 2023-10-30 19:40:19 +00:00
19f3d13102 init 2023-10-30 19:38:04 +00:00
18 changed files with 21356 additions and 75 deletions

test/test_dict.py (Normal file, 4473 changed lines): file diff suppressed because it is too large

test/test_tensorclass.py (Normal file, 1637 changed lines): file diff suppressed because it is too large

View File

@@ -1262,7 +1262,7 @@ def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False
)
output, aux = output
if not isinstance(output, torch.Tensor):
if not isinstance(output, (torch.Tensor, torch.dict.TensorDictBase)):
raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
f'to return a Tensor, got {type(output)}')
if output.dim() != 0:

View File

@@ -5,12 +5,15 @@ import torch
import torch.nn as nn
from torch import Tensor
from torch._functorch.utils import exposed_in
from torch.dict import TensorDictBase
@exposed_in("torch.func")
def functional_call(
module: "torch.nn.Module",
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
parameter_and_buffer_dicts: Union[
Dict[str, Tensor], Sequence[Dict[str, Tensor]], TensorDictBase
],
args: Union[Any, Tuple],
kwargs: Optional[Dict[str, Any]] = None,
*,
@@ -118,7 +121,7 @@ def functional_call(
Returns:
Any: the result of calling ``module``.
"""
if isinstance(parameter_and_buffer_dicts, dict):
if isinstance(parameter_and_buffer_dicts, (dict, TensorDictBase)):
parameters_and_buffers = parameter_and_buffer_dicts
elif isinstance(parameter_and_buffer_dicts, Sequence):
if not all(isinstance(d, dict) for d in parameter_and_buffer_dicts):
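For context, a minimal sketch of what this hunk enables: torch.func.functional_call can now take a TensorDict of parameters in place of a plain dict. TensorDict.from_module and .apply are assumed to exist on this branch as in the tensordict library; they are not shown in this diff.

import torch
import torch.nn as nn
from torch.dict import TensorDict

module = nn.Linear(3, 4)
# from_module/apply are assumed from the tensordict API, not shown in this diff
params = TensorDict.from_module(module).apply(lambda t: torch.zeros_like(t))
out = torch.func.functional_call(module, params, (torch.randn(2, 3),))
assert (out == 0).all()  # every parameter and buffer was zeroed for this call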

View File

@@ -9,12 +9,13 @@ import functools
import threading
from torch import Tensor
from typing import Any, Callable, Optional, Tuple, Union, List
from torch.utils._pytree import (
tree_flatten,
tree_unflatten,
tree_map_,
_broadcast_to_and_flatten,
TreeSpec,
TreeSpec, SUPPORTED_NODES,
)
from functools import partial
import os
@@ -28,9 +29,24 @@ from torch._C._functorch import (
is_batchedtensor,
)
in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]
class _exclude_td_from_pytree:
def __init__(self):
from torch.dict._pytree import PYTREE_REGISTERED_TDS
self.PYTREE_REGISTERED_TDS = PYTREE_REGISTERED_TDS
self.tdnodes = {}
def __enter__(self):
for tdtype in self.PYTREE_REGISTERED_TDS:
self.tdnodes[tdtype] = SUPPORTED_NODES.pop(tdtype)
def __exit__(self, exc_type, exc_val, exc_tb):
for tdtype in self.PYTREE_REGISTERED_TDS:
SUPPORTED_NODES[tdtype] = self.tdnodes[tdtype]
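A small sketch of what the context manager above does, assuming the pytree registration from torch/dict/_pytree.py: inside the block, TensorDicts are treated as pytree leaves rather than being flattened into their values.

import torch
from torch.dict import TensorDict
from torch.utils._pytree import tree_flatten

td = TensorDict({"a": torch.zeros(2)}, batch_size=[2])
leaves, _ = tree_flatten((td,))
print(type(leaves[0]))      # Tensor: the TensorDict was flattened into its values
with _exclude_td_from_pytree():
    leaves, _ = tree_flatten((td,))
    print(type(leaves[0]))  # TensorDict: treated as a leaf inside the context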
def doesnt_support_saved_tensors_hooks(f):
message = (
@@ -80,6 +96,7 @@ def _as_tuple(value: Any, num_elements: int, error_message_lambda: Callable[[],
def _process_batched_inputs(
in_dims: in_dims_t, args: Tuple, func: Callable
) -> Tuple[int, List[Any], List[Any], TreeSpec]:
if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
raise ValueError(
f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
@@ -91,14 +108,17 @@ def _process_batched_inputs(
f'inputs, or you are trying to vmap over a function with no inputs. '
f'The latter is unsupported.')
flat_args, args_spec = tree_flatten(args)
flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
if flat_in_dims is None:
raise ValueError(
f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
f'in_dims is not compatible with the structure of `inputs`. '
f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
f'has structure {args_spec}.')
# we exclude TensorDicts from pytree flattening, as they take care of adding the batch dimension themselves
with _exclude_td_from_pytree():
flat_args, args_spec = tree_flatten(args)
flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
if flat_in_dims is None:
raise ValueError(
f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
f'in_dims is not compatible with the structure of `inputs`. '
f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
f'has structure {args_spec}.'
)
for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)):
if not isinstance(in_dim, int) and in_dim is not None:
@@ -106,7 +126,10 @@ def _process_batched_inputs(
f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
f'Got in_dim={in_dim} for an input but in_dim must be either '
f'an integer dimension or None.')
if isinstance(in_dim, int) and not isinstance(arg, Tensor):
from torch.dict.base import TensorDictBase
if isinstance(in_dim, int) and not isinstance(
arg, (Tensor, TensorDictBase)
):
raise ValueError(
f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
f'Got in_dim={in_dim} for an input but the input is of type '
@@ -130,10 +153,29 @@ def _process_batched_inputs(
def _create_batched_inputs(
flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec) -> Tuple:
# See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
batched_inputs = [arg if in_dim is None else
_add_batch_dim(arg, in_dim, vmap_level)
for in_dim, arg in zip(flat_in_dims, flat_args)]
return tree_unflatten(batched_inputs, args_spec)
# If the argument is a TensorDict, we remove the dim at batch_size[in_dim] so that the TensorDict
# can hold the batched tensors. The dimension is added back in _unwrap_batched.
from torch.dict.base import TensorDictBase
batched_inputs = []
for in_dim, arg in zip(flat_in_dims, flat_args):
if in_dim is None:
if isinstance(arg, TensorDictBase):
# this may be a perf bottleneck and could benefit from caching
# arg = cache(arg.clone)(False)
arg = arg.clone(False)
batched_input = arg
else:
if isinstance(arg, TensorDictBase):
batched_input = arg._add_batch_dim(
in_dim=in_dim, vmap_level=vmap_level
)
else:
batched_input = _add_batch_dim(arg, in_dim, vmap_level)
batched_inputs.append(batched_input)
with _exclude_td_from_pytree():
return tree_unflatten(batched_inputs, args_spec)
def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim):
@@ -160,36 +202,58 @@ def _unwrap_batched(
batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
out_dims: out_dims_t,
vmap_level: int, batch_size: int, func: Callable) -> Tuple:
flat_batched_outputs, output_spec = tree_flatten(batched_outputs)
from torch.dict.base import TensorDictBase
with _exclude_td_from_pytree():
flat_batched_outputs, output_spec = tree_flatten(batched_outputs)
for out in flat_batched_outputs:
# TensorDictBase outputs are accepted alongside plain Tensors:
if isinstance(out, (TensorDictBase, torch.Tensor)):
continue
raise ValueError(
f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
f"Tensors, got type {type(out)} as a return."
)
def incompatible_error():
raise ValueError(
f'vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): '
f'out_dims is not compatible with the structure of `outputs`. '
f'out_dims has structure {tree_flatten(out_dims)[1]} but outputs '
f'has structure {output_spec}.')
f"vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): "
f"out_dims is not compatible with the structure of `outputs`. "
f"out_dims has structure {tree_flatten(out_dims)[1]} but outputs "
f"has structure {output_spec}."
)
if isinstance(batched_outputs, torch.Tensor):
# A lone TensorDictBase output is handled like a lone Tensor output:
if isinstance(batched_outputs, (TensorDictBase, torch.Tensor)):
# Some weird edge case requires us to spell out the following
# see test_out_dims_edge_case
if isinstance(out_dims, int):
flat_out_dims = [out_dims]
elif isinstance(out_dims, tuple) and len(out_dims) == 1:
flat_out_dims = out_dims
elif out_dims is None:
flat_out_dims = [out_dims]
out_dims = out_dims[0]
else:
incompatible_error()
else:
flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec)
if flat_out_dims is None:
incompatible_error()
flat_outputs = [
_maybe_remove_batch_dim(_get_name(func), batched_output, vmap_level, batch_size, out_dim)
for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims)
]
return tree_unflatten(flat_outputs, output_spec)
flat_outputs = []
for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims):
if not isinstance(batched_output, TensorDictBase):
out = _remove_batch_dim(
batched_output,
vmap_level,
batch_size,
out_dim
)
else:
out = batched_output._remove_batch_dim(
vmap_level=vmap_level, batch_size=batch_size, out_dim=out_dim
)
flat_outputs.append(out)
with _exclude_td_from_pytree():
return tree_unflatten(flat_outputs, output_spec)
def _check_int_or_none(x, func, out_dims):
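A rough sketch of the behaviour these vmap changes are after, assuming torch.dict.TensorDict keeps the tensordict semantics for batch_size: a TensorDict can be vmapped over directly, the vmapped dimension is stripped from its batch_size inside the function and restored on the way out.

import torch
from torch.dict import TensorDict

def scale(td):
    # inside vmap, td has batch_size [] and its entries are batched tensors
    return td["x"] * td["w"]

td = TensorDict({"x": torch.randn(4, 3), "w": torch.randn(4, 3)}, batch_size=[4])
out = torch.vmap(scale, in_dims=0)(td)
print(out.shape)  # torch.Size([4, 3])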

torch/dict/__init__.py (Normal file, 7 changed lines)

@@ -0,0 +1,7 @@
from .base import TensorDictBase
from .functional import dense_stack_tds, merge_tensordicts, pad, pad_sequence
from .params import TensorDictParams
from .tensorclass import tensorclass
from .tensordict import TensorDict
from ._pytree import *
from ._lazy import LazyStackedTensorDict
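For orientation, a quick sketch of the surface this package exposes; behaviour is assumed to match the tensordict library it is ported from.

import torch
from torch.dict import TensorDict

td = TensorDict({"obs": torch.randn(4, 3), "done": torch.zeros(4, 1, dtype=torch.bool)}, batch_size=[4])
td["nested", "reward"] = torch.zeros(4)     # nested keys are tuples of strings
print(td[0])                                # indexing applies to every entry along the batch dims
print(td.reshape(2, 2).batch_size)          # torch.Size([2, 2])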

torch/dict/_keys.py (Normal file, 0 changed lines)

torch/dict/_lazy.py (Normal file, 2151 changed lines): file diff suppressed because it is too large

torch/dict/_memmap.py (Normal file, 704 changed lines)

@@ -0,0 +1,704 @@
from __future__ import annotations
import mmap
import os
import sys
import tempfile
from multiprocessing import util
from multiprocessing.context import reduction
from pathlib import Path
from typing import Any, overload
import numpy as np
import torch
from torch import distributed as dist
from torch.multiprocessing.reductions import ForkingPickler
from .utils import implement_for
try:
if dist.is_available():
from torch.distributed._tensor.api import DTensor
else:
raise ImportError
except ImportError:
class DTensor(torch.Tensor): # noqa: D101
...
class MemoryMappedTensor(torch.Tensor):
"""A Memory-mapped Tensor.
Supports filenames or file handlers.
The main advantage of MemoryMappedTensor resides in its serialization methods,
which ensure that the tensor is passed through queues or RPC remote calls without
any copy.
.. note::
When used within RPC settings, the filepath should be accessible to both nodes.
If it isn't, the behaviour of passing a MemoryMappedTensor from one worker
to another is undefined.
MemoryMappedTensor supports multiple construction methods.
Examples:
>>> # from an existing tensor
>>> tensor = torch.randn(3)
>>> with tempfile.NamedTemporaryFile() as file:
... memmap_tensor = MemoryMappedTensor.from_tensor(tensor, filename=file.name)
... assert memmap_tensor.filename is not None
>>> # if no filename is passed, a handler is used
>>> tensor = torch.randn(3)
>>> memmap_tensor = MemoryMappedTensor.from_tensor(tensor)
>>> assert memmap_tensor._filename is None
>>> # one can create an empty tensor too
>>> with tempfile.NamedTemporaryFile() as file:
... memmap_tensor_empty = MemoryMappedTensor.empty_like(tensor, filename=file.name)
>>> with tempfile.NamedTemporaryFile() as file:
... memmap_tensor_zero = MemoryMappedTensor.zeros_like(tensor, filename=file.name)
>>> with tempfile.NamedTemporaryFile() as file:
... memmap_tensor = MemoryMappedTensor.ones_like(tensor, filename=file.name)
"""
_filename: str | Path
_handler: _FileHandler
_clear: bool
index: Any
parent_shape: torch.Size
def __new__(
cls,
tensor_or_file,
*,
dtype=None,
shape=None,
index=None,
device=None,
handler=None,
):
if device is not None and torch.device(device).type != "cpu":
raise ValueError(f"{cls} device must be cpu!")
if isinstance(tensor_or_file, str):
return cls.from_filename(
tensor_or_file,
dtype,
shape,
index,
)
elif handler is not None:
return cls.from_handler(
handler,
dtype,
shape,
index,
)
return super().__new__(cls, tensor_or_file)
def __init__(
self, tensor_or_file, handler=None, dtype=None, shape=None, device=None
):
...
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def from_tensor(
cls,
input,
*,
filename=None,
existsok=False,
copy_existing=False,
copy_data=True,
):
"""Creates a MemoryMappedTensor with the same content as another tensor.
If the tensor is already a MemoryMappedTensor the original tensor is
returned if the `filename` argument is `None` or if the two paths match.
In all other cases, a new :class:`MemoryMappedTensor` is produced.
Args:
input (torch.Tensor): the tensor whose content must be copied onto
the MemoryMappedTensor.
filename (path to a file): the path to the file where the tensor
should be stored. If none is provided, a file handler is used
instead.
existsok (bool, optional): if ``True``, the file will overwrite
an existing file. Defaults to ``False``.
copy_existing (bool, optional): if ``True`` and the provided input
is a MemoryMappedTensor with an associated filename, copying
the content to the new location is permitted. Otherwise an
exception is thrown. This behaviour exists to prevent
inadvertently duplicating data on disk.
copy_data (bool, optional): if ``True``, the content of the tensor
will be copied on the storage. Defaults to ``True``.
"""
if isinstance(input, MemoryMappedTensor):
if (filename is None and input._filename is None) or (
input._filename is not None
and filename is not None
and Path(filename).absolute() == Path(input.filename).absolute()
):
# either location was not specified, or memmap is already in the
# correct location, so just return the MemmapTensor unmodified
return input
elif not copy_existing and (
input._filename is not None
and filename is not None
and Path(filename).absolute() != Path(input.filename).absolute()
):
raise RuntimeError(
f"A filename was provided but the tensor already has a file associated "
f"({input.filename}). "
f"To copy the tensor onto the new location, pass copy_existing=True."
)
elif isinstance(input, np.ndarray):
raise TypeError(
"Convert input to torch.Tensor before calling MemoryMappedTensor.from_tensor."
)
if input.requires_grad:
raise RuntimeError(
"MemoryMappedTensor.from_tensor is incompatible with tensor.requires_grad."
)
shape = input.shape
if filename is None:
if input.dtype.is_floating_point:
size = torch.finfo(input.dtype).bits // 8 * shape.numel()
elif input.dtype.is_complex:
raise ValueError(
"Complex-valued tensors are not supported by MemoryMappedTensor."
)
elif input.dtype == torch.bool:
size = shape.numel()
else:
# assume integer
size = torch.iinfo(input.dtype).bits // 8 * shape.numel()
handler = _FileHandler(size)
out = torch.frombuffer(memoryview(handler.buffer), dtype=input.dtype)
out = out.view(shape)
out = cls(out)
else:
handler = None
if not existsok and os.path.exists(str(filename)):
raise RuntimeError(f"The file {filename} already exists.")
out = cls(
torch.from_file(
str(filename), shared=True, dtype=input.dtype, size=shape.numel()
).view(input.shape)
)
out._handler = handler
out._filename = filename
out.index = None
out.parent_shape = input.shape
if copy_data:
if isinstance(input, DTensor):
input = input.full_tensor()
out.copy_(input)
return out
@property
def filename(self):
"""The filename of the tensor, if it has one.
Raises an exception otherwise.
"""
filename = self._filename
if filename is None:
raise RuntimeError("The MemoryMappedTensor has no file associated.")
return filename
@classmethod
def empty_like(cls, input, *, filename=None):
# noqa: D417
"""Creates a tensor with no content but the same shape and dtype as the input tensor.
Args:
input (torch.Tensor): the tensor to use as an example.
Keyword Args:
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
return cls.from_tensor(
torch.zeros((), dtype=input.dtype, device=input.device).expand_as(input),
filename=filename,
copy_data=False,
)
@classmethod
def full_like(cls, input, fill_value, *, filename=None):
# noqa: D417
"""Creates a tensor with a single content indicated by the `fill_value` argument, but the same shape and dtype as the input tensor.
Args:
input (torch.Tensor): the tensor to use as an example.
fill_value (float or equivalent): content of the tensor.
Keyword Args:
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
return cls.from_tensor(
torch.zeros((), dtype=input.dtype, device=input.device).expand_as(input),
filename=filename,
copy_data=False,
).fill_(fill_value)
@classmethod
def zeros_like(cls, input, *, filename=None):
# noqa: D417
"""Creates a tensor with a 0-filled content, but the same shape and dtype as the input tensor.
Args:
input (torch.Tensor): the tensor to use as an example.
Keyword Args:
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
return cls.from_tensor(
torch.zeros((), dtype=input.dtype, device=input.device).expand_as(input),
filename=filename,
copy_data=False,
).fill_(0.0)
@classmethod
def ones_like(cls, input, *, filename=None):
# noqa: D417
"""Creates a tensor with a 1-filled content, but the same shape and dtype as the input tensor.
Args:
input (torch.Tensor): the tensor to use as an example.
Keyword Args:
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
return cls.from_tensor(
torch.ones((), dtype=input.dtype, device=input.device).expand_as(input),
filename=filename,
copy_data=False,
).fill_(1.0)
@classmethod
@overload
def ones(cls, *size, dtype=None, device=None, filename=None):
...
@classmethod
@overload
def ones(cls, shape, *, dtype=None, device=None, filename=None):
...
@classmethod
def ones(cls, *args, **kwargs):
# noqa: D417
"""Creates a tensor with a 1-filled content, specific shape, dtype and filename.
Args:
shape (integers or torch.Size): the shape of the tensor.
Keyword Args:
dtype (torch.dtype): the dtype of the tensor.
device (torch.device): the device of the tensor. Only `None` and `"cpu"`
are accepted, any other device will raise an exception.
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
if device is not None:
device = torch.device(device)
if device.type != "cpu":
raise RuntimeError("Only CPU tensors are supported.")
result = torch.ones((), dtype=dtype, device=device)
if shape:
if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
shape = torch.Size(shape[0])
else:
shape = torch.Size(shape)
result = result.expand(shape)
return cls.from_tensor(
result,
filename=filename,
)
@classmethod
@overload
def zeros(cls, *size, dtype=None, device=None, filename=None):
...
@classmethod
@overload
def zeros(cls, shape, *, dtype=None, device=None, filename=None):
...
@classmethod
def zeros(cls, *args, **kwargs):
# noqa: D417
"""Creates a tensor with a 0-filled content, specific shape, dtype and filename.
Args:
shape (integers or torch.Size): the shape of the tensor.
Keyword Args:
dtype (torch.dtype): the dtype of the tensor.
device (torch.device): the device of the tensor. Only `None` and `"cpu"`
are accepted, any other device will raise an exception.
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
if device is not None:
device = torch.device(device)
if device.type != "cpu":
raise RuntimeError("Only CPU tensors are supported.")
result = torch.zeros((), dtype=dtype, device=device)
if shape:
if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
shape = torch.Size(shape[0])
else:
shape = torch.Size(shape)
result = result.expand(shape)
result = cls.from_tensor(
result,
filename=filename,
)
return result
@classmethod
@overload
def empty(cls, *size, dtype=None, device=None, filename=None):
...
@classmethod
@overload
def empty(cls, shape, *, dtype=None, device=None, filename=None):
...
@classmethod
def empty(cls, *args, **kwargs):
# noqa: D417
"""Creates a tensor with empty content, specific shape, dtype and filename.
Args:
shape (integers or torch.Size): the shape of the tensor.
Keyword Args:
dtype (torch.dtype): the dtype of the tensor.
device (torch.device): the device of the tensor. Only `None` and `"cpu"`
are accepted, any other device will raise an exception.
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
if device is not None:
device = torch.device(device)
if device.type != "cpu":
raise RuntimeError("Only CPU tensors are supported.")
result = torch.zeros((), dtype=dtype, device=device)
if shape:
if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
shape = torch.Size(shape[0])
else:
shape = torch.Size(shape)
result = result.expand(shape)
result = cls.from_tensor(result, filename=filename)
return result
@classmethod
@overload
def full(cls, *size, fill_value, dtype=None, device=None, filename=None):
...
@classmethod
@overload
def full(cls, shape, *, fill_value, dtype=None, device=None, filename=None):
...
@classmethod
def full(cls, *args, **kwargs):
# noqa: D417
"""Creates a tensor with a single content specified by `fill_value`, specific shape, dtype and filename.
Args:
shape (integers or torch.Size): the shape of the tensor.
Keyword Args:
fill_value (float or equivalent): content of the tensor.
dtype (torch.dtype): the dtype of the tensor.
device (torch.device): the device of the tensor. Only `None` and `"cpu"`
are accepted, any other device will raise an exception.
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
shape, device, dtype, fill_value, filename = _proc_args_const(*args, **kwargs)
if device is not None:
device = torch.device(device)
if device.type != "cpu":
raise RuntimeError("Only CPU tensors are supported.")
result = torch.zeros((), dtype=dtype, device=device).fill_(fill_value)
if shape:
if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
shape = torch.Size(shape[0])
else:
shape = torch.Size(shape)
result = result.expand(shape)
return cls.from_tensor(result, filename=filename)
@classmethod
def from_filename(cls, filename, dtype, shape, index=None):
# noqa: D417
"""Loads a MemoryMappedTensor from a given filename.
Args:
filename (path or equivalent): the path to the file.
dtype (torch.dtype): the dtype of the tensor.
shape (integers or torch.Size): the shape of the tensor.
index (torch-compatible index type): an index to use to build the
tensor.
"""
shape = torch.Size(shape)
tensor = torch.from_file(
str(filename), shared=True, dtype=dtype, size=shape.numel()
).view(shape)
if index is not None:
tensor = tensor[index]
out = cls(tensor)
out._filename = filename
out._handler = None
out.index = index
out.parent_shape = shape
return out
@classmethod
def from_handler(cls, handler, dtype, shape, index):
# noqa: D417
"""Loads a MemoryMappedTensor from a given handler.
Args:
handler (compatible file handler): the handler for the tensor.
dtype (torch.dtype): the dtype of the tensor.
shape (integers or torch.Size): the shape of the tensor.
index (torch-compatible index type): an index to use to build the
tensor.
"""
shape = torch.Size(shape)
out = torch.frombuffer(memoryview(handler.buffer), dtype=dtype)
out = torch.reshape(out, shape)
if index is not None:
out = out[index]
out = cls(out)
out._filename = None
out._handler = handler
out.index = index
out.parent_shape = shape
return out
@property
def _tensor(self):
# for bc-compatibility with MemmapTensor, to be deprecated in v0.4
return self
def __setstate__(self, state):
if "filename" in state:
self.__dict__ = type(self).from_filename(**state).__dict__
else:
self.__dict__ = type(self).from_handler(**state).__dict__
def __getstate__(self):
if getattr(self, "_handler", None) is not None:
return {
"handler": self._handler,
"dtype": self.dtype,
"shape": self.parent_shape,
"index": self.index,
}
elif getattr(self, "_filename", None) is not None:
return {
"filename": self._filename,
"dtype": self.dtype,
"shape": self.parent_shape,
"index": self.index,
}
else:
raise RuntimeError("Could not find handler or filename.")
def __reduce_ex__(self, protocol):
return self.__reduce__()
def __reduce__(self):
if getattr(self, "_handler", None) is not None:
return type(self).from_handler, (
self._handler,
self.dtype,
self.parent_shape,
self.index,
)
elif getattr(self, "_filename", None) is not None:
return type(self).from_filename, (
self._filename,
self.dtype,
self.parent_shape,
self.index,
)
else:
raise RuntimeError("Could not find handler or filename.")
@implement_for("torch", "2.0", None)
def __getitem__(self, item):
try:
out = super().__getitem__(item)
except ValueError as err:
if "is unbound" in str(err):
raise ValueError(
"Using first class dimension indices with MemoryMappedTensor "
"isn't supported at the moment."
) from err
raise
if out.untyped_storage().data_ptr() == self.untyped_storage().data_ptr():
out = MemoryMappedTensor(out)
out._handler = self._handler
out._filename = self._filename
out.index = item
out.parent_shape = self.parent_shape
return out
@implement_for("torch", None, "2.0")
def __getitem__(self, item): # noqa: F811
try:
out = super().__getitem__(item)
except ValueError as err:
if "is unbound" in str(err):
raise ValueError(
"Using first class dimension indices with MemoryMappedTensor "
"isn't supported at the moment."
) from err
raise
if out.storage().data_ptr() == self.storage().data_ptr():
out = MemoryMappedTensor(out)
out._handler = self._handler
out._filename = self._filename
out.index = item
out.parent_shape = self.parent_shape
return out
#####################
# File handler
# borrowed from mp.heap
if sys.platform == "win32":
import _winapi
class _FileHandler:
_rand = tempfile._RandomNameSequence()
def __init__(self, size):
self.size = size
for _ in range(100):
name = "pym-%d-%s" % (os.getpid(), next(self._rand))
buf = mmap.mmap(-1, size, tagname=name)
if _winapi.GetLastError() == 0:
break
# We have reopened a preexisting mmap.
buf.close()
else:
raise FileExistsError("Cannot find name for new mmap")
self.name = name
self.buffer = buf
self._state = (self.size, self.name)
def __getstate__(self):
from multiprocessing.context import assert_spawning
assert_spawning(self)
return self._state
def __setstate__(self, state):
self.size, self.name = self._state = state
# Reopen existing mmap
self.buffer = mmap.mmap(-1, self.size, tagname=self.name)
# XXX Temporarily preventing buildbot failures while determining
# XXX the correct long-term fix. See issue 23060
# assert _winapi.GetLastError() == _winapi.ERROR_ALREADY_EXISTS
else:
class _FileHandler:
if sys.platform == "linux":
_dir_candidates = ["/dev/shm"]
else:
_dir_candidates = []
def __init__(self, size, fd=-1):
self.size = size
self.fd = fd
if fd == -1:
self.fd, name = tempfile.mkstemp(
prefix="pym-%d-" % os.getpid(), dir=self._choose_dir(size)
)
os.unlink(name)
util.Finalize(self, os.close, (self.fd,))
os.ftruncate(self.fd, size)
self.buffer = mmap.mmap(self.fd, self.size)
def _choose_dir(self, size):
# Choose a non-storage backed directory if possible,
# to improve performance
for d in self._dir_candidates:
st = os.statvfs(d)
if st.f_bavail * st.f_frsize >= size: # enough free space?
return d
return util.get_temp_dir()
def _reduce_handler(handler):
if handler.fd == -1:
raise ValueError(
"Handler is unpicklable because "
"forking was enabled when it was created"
)
return _rebuild_handler, (handler.size, reduction.DupFd(handler.fd))
def _rebuild_handler(size, dupfd):
detached = dupfd.detach()
return _FileHandler(size, detached)
reduction.register(_FileHandler, _reduce_handler)
def _reduce_memmap(memmap_tensor):
return memmap_tensor.__reduce__()
ForkingPickler.register(MemoryMappedTensor, _reduce_memmap)
def _proc_args_const(*args, **kwargs):
if len(args) > 0:
# then the first (or the N first) args are the shape
if len(args) == 1 and not isinstance(args[0], int):
shape = torch.Size(args[0])
else:
shape = torch.Size(args)
else:
# we should have a "shape" keyword arg
shape = kwargs.pop("shape", None)
if shape is None:
raise TypeError("Could not find the shape argument in the arguments.")
shape = torch.Size(shape)
return (
shape,
kwargs.pop("device", None),
kwargs.pop("dtype", None),
kwargs.pop("fill_value", None),
kwargs.pop("filename", None),
)
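A usage sketch built only from the constructors defined above; the temporary path is illustrative.

import tempfile

import torch
from torch.dict._memmap import MemoryMappedTensor

with tempfile.TemporaryDirectory() as tmpdir:
    filename = f"{tmpdir}/weights.memmap"
    mm = MemoryMappedTensor.zeros(16, 4, dtype=torch.float32, filename=filename)
    mm.fill_(1.0)  # writes go straight to the memory-mapped file
    reloaded = MemoryMappedTensor.from_filename(filename, torch.float32, (16, 4))
    assert (reloaded == 1.0).all()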

torch/dict/_pytree.py (Normal file, 89 changed lines)

@@ -0,0 +1,89 @@
from typing import Any, Dict, List, Tuple
from torch.dict import TensorDictParams
from torch.dict.tensordict import _SubTensorDict, TensorDict
from torch.utils._pytree import Context, register_pytree_node
PYTREE_REGISTERED_TDS = (
TensorDict,
TensorDictParams,
_SubTensorDict,
)
__all__ = []
def _str_to_dict(str_spec: str) -> Tuple[List[str], str]:
assert str_spec[1] == "("
assert str_spec[-1] == ")"
context_and_child_strings = str_spec[2:-1]
child_strings = []
context_strings = []
nested_parentheses = 0
start_index = 0
for i, char in enumerate(context_and_child_strings):
if char == ":":
if nested_parentheses == 0:
context_strings.append(context_and_child_strings[start_index:i])
start_index = i + 1
elif char == "(":
nested_parentheses += 1
elif char == ")":
nested_parentheses -= 1
if nested_parentheses == 0 and char == ",":
child_strings.append(context_and_child_strings[start_index:i])
start_index = i + 1
child_strings.append(context_and_child_strings[start_index:])
return context_strings, ",".join(child_strings)
def _str_to_tensordictdict(str_spec: str) -> Tuple[List[str], str]:
context_and_child_strings = str_spec[2:-1]
child_strings = []
context_strings = []
nested_parentheses = 0
start_index = 0
for i, char in enumerate(context_and_child_strings):
if char == ":":
if nested_parentheses == 0:
context_strings.append(context_and_child_strings[start_index:i])
start_index = i + 1
elif char == "(":
nested_parentheses += 1
elif char == ")":
nested_parentheses -= 1
if nested_parentheses == 0 and char == ",":
child_strings.append(context_and_child_strings[start_index:i])
start_index = i + 1
child_strings.append(context_and_child_strings[start_index:])
return context_strings, ",".join(child_strings)
def _tensordict_flatten(d: TensorDict) -> Tuple[List[Any], Context]:
return list(d.values()), {
"keys": list(d.keys()),
"batch_size": d.batch_size,
"names": d.names,
}
def _tensordictdict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
return TensorDict(
dict(zip(context["keys"], values)),
context["batch_size"],
names=context["names"],
)
for cls in PYTREE_REGISTERED_TDS:
register_pytree_node(
cls,
_tensordict_flatten,
_tensordictdict_unflatten,
)
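A short sketch of what this registration provides, using only the flatten/unflatten pair defined above.

import torch
from torch.dict import TensorDict
from torch.utils._pytree import tree_flatten, tree_unflatten

td = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, batch_size=[3])
leaves, spec = tree_flatten(td)                       # leaves are the stored tensors
td2 = tree_unflatten([x + 1 for x in leaves], spec)   # keys, batch_size and names come from the context
assert (td2["b"] == 2).all()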

torch/dict/_torch_func.py (Normal file, 419 changed lines)

@@ -0,0 +1,419 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import functools
from typing import Any, Callable, Sequence, TypeVar
import torch
from torch import Tensor
from ._lazy import LazyStackedTensorDict
from .base import NO_DEFAULT, TensorDictBase
from .tensordict import TensorDict
from .utils import _check_keys, _ErrorInteceptor, DeviceType
TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
LAZY_TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
T = TypeVar("T", bound="TensorDictBase")
def implements_for_td(torch_function: Callable) -> Callable[[Callable], Callable]:
"""Register a torch function override for TensorDict."""
@functools.wraps(torch_function)
def decorator(func: Callable) -> Callable:
TD_HANDLED_FUNCTIONS[torch_function] = func
return func
return decorator
def implements_for_lazy_td(torch_function: Callable) -> Callable[[Callable], Callable]:
"""Register a torch function override for TensorDict."""
@functools.wraps(torch_function)
def decorator(func: Callable) -> Callable:
LAZY_TD_HANDLED_FUNCTIONS[torch_function] = func
return func
return decorator
@implements_for_td(torch.unbind)
def _unbind(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]:
return td.unbind(*args, **kwargs)
@implements_for_td(torch.gather)
def _gather(
input: T,
dim: int,
index: Tensor,
*,
sparse_grad: bool = False,
out: T | None = None,
) -> T:
if sparse_grad:
raise NotImplementedError(
"sparse_grad=True not implemented for torch.gather(tensordict, ...)"
)
# the index must have as many dims as the tensordict
if not len(index):
raise RuntimeError("Cannot use torch.gather with an empty index")
dim_orig = dim
if dim < 0:
dim = input.batch_dims + dim
if dim > input.batch_dims - 1 or dim < 0:
raise RuntimeError(
f"Cannot gather tensordict with shape {input.shape} along dim {dim_orig}."
)
def _gather_tensor(tensor, dest=None):
index_expand = index
while index_expand.ndim < tensor.ndim:
index_expand = index_expand.unsqueeze(-1)
target_shape = list(tensor.shape)
target_shape[dim] = index_expand.shape[dim]
index_expand = index_expand.expand(target_shape)
out = torch.gather(tensor, dim, index_expand, out=dest)
return out
if out is None:
names = input.names if input._has_names() else None
return TensorDict(
{key: _gather_tensor(value) for key, value in input.items()},
batch_size=index.shape,
names=names,
)
TensorDict(
{key: _gather_tensor(value, out[key]) for key, value in input.items()},
batch_size=index.shape,
)
return out
@implements_for_td(torch.full_like)
def _full_like(td: T, fill_value: float, **kwargs: Any) -> T:
td_clone = td.clone()
for key in td_clone.keys():
td_clone.fill_(key, fill_value)
if "dtype" in kwargs:
raise ValueError("Cannot pass dtype to full_like with TensorDict")
if "device" in kwargs:
td_clone = td_clone.to(kwargs.pop("device"))
if len(kwargs):
raise RuntimeError(
f"keyword arguments {list(kwargs.keys())} are not "
f"supported with full_like with TensorDict"
)
return td_clone
@implements_for_td(torch.zeros_like)
def _zeros_like(td: T, **kwargs: Any) -> T:
td_clone = td._fast_apply(torch.zeros_like)
if "dtype" in kwargs:
raise ValueError("Cannot pass dtype to full_like with TensorDict")
if "device" in kwargs:
td_clone = td_clone.to(kwargs.pop("device"))
if len(kwargs):
raise RuntimeError(
f"keyword arguments {list(kwargs.keys())} are not "
f"supported with full_like with TensorDict"
)
return td_clone
@implements_for_td(torch.ones_like)
def _ones_like(td: T, **kwargs: Any) -> T:
td_clone = td._fast_apply(lambda x: torch.ones_like(x))
if "device" in kwargs:
td_clone = td_clone.to(kwargs.pop("device"))
if len(kwargs):
raise RuntimeError(
f"keyword arguments {list(kwargs.keys())} are not "
f"supported with full_like with TensorDict"
)
return td_clone
@implements_for_td(torch.empty_like)
def _empty_like(td: T, *args, **kwargs) -> T:
try:
tdclone = td.clone()
except Exception as err:
raise RuntimeError(
"The tensordict passed to torch.empty_like cannot be "
"cloned, preventing empty_like to be called. "
"Consider calling tensordict.to_tensordict() first."
) from err
return tdclone._fast_apply(
lambda x: torch.empty_like(x, *args, **kwargs), inplace=True
)
@implements_for_td(torch.clone)
def _clone(td: T, *args: Any, **kwargs: Any) -> T:
return td.clone(*args, **kwargs)
@implements_for_td(torch.squeeze)
def _squeeze(td: T, *args: Any, **kwargs: Any) -> T:
return td.squeeze(*args, **kwargs)
@implements_for_td(torch.unsqueeze)
def _unsqueeze(td: T, *args: Any, **kwargs: Any) -> T:
return td.unsqueeze(*args, **kwargs)
@implements_for_td(torch.masked_select)
def _masked_select(td: T, *args: Any, **kwargs: Any) -> T:
return td.masked_select(*args, **kwargs)
@implements_for_td(torch.permute)
def _permute(td: T, dims: Sequence[int]) -> T:
return td.permute(*dims)
@implements_for_td(torch.cat)
def _cat(
list_of_tensordicts: Sequence[T],
dim: int = 0,
device: DeviceType | None = None,
out: T | None = None,
) -> T:
if not list_of_tensordicts:
raise RuntimeError("list_of_tensordicts cannot be empty")
batch_size = list(list_of_tensordicts[0].batch_size)
if dim < 0:
dim = len(batch_size) + dim
if dim >= len(batch_size):
raise RuntimeError(
f"dim must be in the range 0 <= dim < len(batch_size), got dim"
f"={dim} and batch_size={batch_size}"
)
batch_size[dim] = sum([td.batch_size[dim] for td in list_of_tensordicts])
batch_size = torch.Size(batch_size)
# check that all tensordict match
keys = _check_keys(list_of_tensordicts, strict=True)
if out is None:
out = {}
for key in keys:
with _ErrorInteceptor(
key, "Attempted to concatenate tensors on different devices at key"
):
out[key] = torch.cat(
[td._get_str(key, NO_DEFAULT) for td in list_of_tensordicts], dim
)
if device is None:
device = list_of_tensordicts[0].device
for td in list_of_tensordicts[1:]:
if device == td.device:
continue
else:
device = None
break
names = None
if list_of_tensordicts[0]._has_names():
names = list_of_tensordicts[0].names
return TensorDict(
out, device=device, batch_size=batch_size, _run_checks=False, names=names
)
else:
if out.batch_size != batch_size:
raise RuntimeError(
"out.batch_size and cat batch size must match, "
f"got out.batch_size={out.batch_size} and batch_size"
f"={batch_size}"
)
for key in keys:
with _ErrorInteceptor(
key, "Attempted to concatenate tensors on different devices at key"
):
if isinstance(out, TensorDict):
torch.cat(
[td.get(key) for td in list_of_tensordicts],
dim,
out=out.get(key),
)
else:
out.set_(
key, torch.cat([td.get(key) for td in list_of_tensordicts], dim)
)
return out
@implements_for_lazy_td(torch.cat)
def _lazy_cat(
list_of_tensordicts: Sequence[LazyStackedTensorDict],
dim: int = 0,
out: LazyStackedTensorDict | None = None,
) -> LazyStackedTensorDict:
# concatenate LazyStackedTensorDicts along dim without densifying the lazy stacks
if not list_of_tensordicts:
raise RuntimeError("list_of_tensordicts cannot be empty")
batch_size = list(list_of_tensordicts[0].batch_size)
if dim < 0:
dim = len(batch_size) + dim
if dim >= len(batch_size):
raise RuntimeError(
f"dim must be in the range 0 <= dim < len(batch_size), got dim"
f"={dim} and batch_size={batch_size}"
)
stack_dim = list_of_tensordicts[0].stack_dim
if any((td.stack_dim != stack_dim) for td in list_of_tensordicts):
raise RuntimeError("cat lazy stacked tds must have same stack dim")
batch_size[dim] = sum(td.batch_size[dim] for td in list_of_tensordicts)
batch_size = torch.Size(batch_size)
new_dim = dim
if dim > stack_dim:
new_dim = dim - 1
if out is None:
out = []
if dim == stack_dim: # if dim is stack, just add all to the same list
for lazy_td in list_of_tensordicts:
out += lazy_td.tensordicts
else:
for i in range(len(list_of_tensordicts[0].tensordicts)):
out.append(
torch.cat(
[lazy_td.tensordicts[i] for lazy_td in list_of_tensordicts],
new_dim,
)
)
return LazyStackedTensorDict(*out, stack_dim=stack_dim)
else:
if not isinstance(out, LazyStackedTensorDict):
return _cat(list_of_tensordicts, dim=dim, out=out)
if out.batch_size != batch_size:
raise RuntimeError(
"out.batch_size and cat batch size must match, "
f"got out.batch_size={out.batch_size} and batch_size"
f"={batch_size}"
)
if out.stack_dim != dim:
index_base = (slice(None),) * out.stack_dim
for i, sub_dest in enumerate(out.tensordicts):
index = index_base + (i,)
tds_to_cat = [_td[index] for _td in list_of_tensordicts]
torch.cat(tds_to_cat, dim, out=sub_dest)
else:
init_idx = 0
for td_in in list_of_tensordicts:
sub_dest = out.tensordicts[init_idx : init_idx + td_in.shape[dim]]
init_idx = init_idx + td_in.shape[dim]
torch.stack(sub_dest, out.stack_dim).update(td_in, inplace=True)
return out
@implements_for_td(torch.stack)
def _stack(
list_of_tensordicts: Sequence[TensorDictBase],
dim: int = 0,
device: DeviceType | None = None,
out: T | None = None,
strict: bool = False,
contiguous: bool = False,
) -> T:
if not list_of_tensordicts:
raise RuntimeError("list_of_tensordicts cannot be empty")
batch_size = list_of_tensordicts[0].batch_size
if dim < 0:
dim = len(batch_size) + dim + 1
for td in list_of_tensordicts[1:]:
if td.batch_size != list_of_tensordicts[0].batch_size:
raise RuntimeError(
"stacking tensordicts requires them to have congruent batch sizes, "
f"got td1.batch_size={td.batch_size} and td2.batch_size="
f"{list_of_tensordicts[0].batch_size}"
)
# check that all tensordict match
keys = _check_keys(list_of_tensordicts)
result_batch_size = list(batch_size)
result_batch_size.insert(dim, len(list_of_tensordicts))
result_batch_size = torch.Size(result_batch_size)
if out is None:
device = list_of_tensordicts[0].device
out = {}
for key in keys:
with _ErrorInteceptor(
key, "Attempted to stack tensors on different devices at key"
):
out[key] = torch.stack(
[_tensordict.get(key) for _tensordict in list_of_tensordicts],
dim,
)
return TensorDict(
out,
batch_size=result_batch_size,
device=device,
_run_checks=False,
)
else:
if out.batch_size != result_batch_size:
raise RuntimeError(
"out.batch_size and stacked batch size must match, "
f"got out.batch_size={out.batch_size} and resulting batch_size"
f"={result_batch_size}"
)
out_keys = set(out.keys())
if strict:
in_keys = set(keys)
if len(out_keys - in_keys) > 0:
raise RuntimeError(
"The output tensordict has keys that are missing in the "
"tensordict that has to be written: {out_keys - in_keys}. "
"As per the call to `stack(..., strict=True)`, this "
"is not permitted."
)
elif len(in_keys - out_keys) > 0:
raise RuntimeError(
"The resulting tensordict has keys that are missing in "
f"its destination: {in_keys - out_keys}. As per the call "
"to `stack(..., strict=True)`, this is not permitted."
)
out._stack_onto_(list_of_tensordicts, dim)
return out
@implements_for_td(torch.split)
def _split(
td: TensorDict, split_size_or_sections: int | list[int], dim: int = 0
) -> list[TensorDictBase]:
return td.split(split_size_or_sections, dim)
@implements_for_td(torch.where)
def where(condition, input, other, *, out=None):
"""Return a ``TensorDict`` of elements selected from either input or other, depending on condition.
Args:
condition (BoolTensor): When ``True`` (nonzero), yield ``input``, otherwise yield ``other``.
input (TensorDictBase or Scalar): value (if ``input`` is a scalar) or values selected at indices where condition is ``True``.
other (TensorDictBase or Scalar): value (if ``other`` is a scalar) or values selected at indices where condition is ``False``.
out (Tensor, optional): the output ``TensorDictBase`` instance.
"""
return input.where(condition, other, out=out)
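A sketch of the overrides above in action. It assumes TensorDict's __torch_function__ (defined in the suppressed base.py/tensordict.py diffs) routes these calls to the handlers registered here, as in the tensordict library.

import torch
from torch.dict import TensorDict

td0 = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3])
td1 = TensorDict({"a": torch.ones(3, 4)}, batch_size=[3])
print(torch.cat([td0, td1], dim=0).batch_size)    # torch.Size([6]), via _cat
print(torch.stack([td0, td1], dim=0).batch_size)  # torch.Size([2, 3]), via _stack
print(torch.zeros_like(td1)["a"].sum())           # tensor(0.), via _zeros_like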

torch/dict/base.py (Normal file, 4500 changed lines): file diff suppressed because it is too large

torch/dict/functional.py (Normal file, 338 changed lines)

@@ -0,0 +1,338 @@
from __future__ import annotations
from typing import Sequence, TypeVar
import torch
from ._lazy import LazyStackedTensorDict
from .base import _is_tensor_collection, CompatibleType, TensorDictBase
from .tensordict import TensorDict
from .utils import _check_keys, _shape, DeviceType
T = TypeVar("T", bound="TensorDictBase")
def pad(tensordict: T, pad_size: Sequence[int], value: float = 0.0) -> T:
"""Pads all tensors in a tensordict along the batch dimensions with a constant value, returning a new tensordict.
Args:
tensordict (TensorDict): The tensordict to pad
pad_size (Sequence[int]): The padding size by which to pad some batch
dimensions of the tensordict, starting from the first dimension and
moving forward. [len(pad_size) / 2] dimensions of the batch size will
be padded. For example to pad only the first dimension, pad has the form
(padding_left, padding_right). To pad two dimensions,
(padding_left, padding_right, padding_top, padding_bottom) and so on.
pad_size must be even and less than or equal to twice the number of batch dimensions.
value (float, optional): The fill value to pad by, default 0.0
Returns:
A new TensorDict padded along the batch dimensions
Examples:
>>> from torch.dict import TensorDict, pad
>>> import torch
>>> td = TensorDict({'a': torch.ones(3, 4, 1),
... 'b': torch.ones(3, 4, 1, 1)}, batch_size=[3, 4])
>>> dim0_left, dim0_right, dim1_left, dim1_right = [0, 1, 0, 2]
>>> padded_td = pad(td, [dim0_left, dim0_right, dim1_left, dim1_right], value=0.0)
>>> print(padded_td.batch_size)
torch.Size([4, 6])
>>> print(padded_td.get("a").shape)
torch.Size([4, 6, 1])
>>> print(padded_td.get("b").shape)
torch.Size([4, 6, 1, 1])
"""
if len(pad_size) > 2 * len(tensordict.batch_size):
raise RuntimeError(
"The length of pad_size must be <= 2 * the number of batch dimensions"
)
if len(pad_size) % 2:
raise RuntimeError("pad_size must have an even number of dimensions")
new_batch_size = list(tensordict.batch_size)
for i in range(len(pad_size)):
new_batch_size[i // 2] += pad_size[i]
reverse_pad = pad_size[::-1]
for i in range(0, len(reverse_pad), 2):
reverse_pad[i], reverse_pad[i + 1] = reverse_pad[i + 1], reverse_pad[i]
out = TensorDict(
{}, torch.Size(new_batch_size), device=tensordict.device, _run_checks=False
)
for key, tensor in tensordict.items():
cur_pad = reverse_pad
if len(pad_size) < len(_shape(tensor)) * 2:
cur_pad = [0] * (len(_shape(tensor)) * 2 - len(pad_size)) + reverse_pad
if _is_tensor_collection(tensor.__class__):
padded = pad(tensor, pad_size, value)
else:
padded = torch.nn.functional.pad(tensor, cur_pad, value=value)
out.set(key, padded)
return out
def pad_sequence(
list_of_tensordicts: Sequence[T],
batch_first: bool = True,
padding_value: float = 0.0,
out: T | None = None,
device: DeviceType | None = None,
return_mask: bool | None = False,
) -> T:
"""Pads a list of tensordicts in order for them to be stacked together in a contiguous format.
Args:
list_of_tensordicts (List[TensorDictBase]): the list of instances to pad and stack.
batch_first (bool, optional): the ``batch_first`` argument of :func:`torch.nn.utils.rnn.pad_sequence`.
Defaults to ``True``.
padding_value (number, optional): the padding value. Defaults to ``0.0``.
out (TensorDictBase, optional): if provided, the destination where the data will be
written.
device (device compatible type, optional): if provided, the device where the
TensorDict output will be created.
return_mask (bool, optional): if ``True``, a "mask" entry will be returned.
It contains the mask of valid values in the stacked tensordict.
Examples:
>>> list_td = [
... TensorDict({"a": torch.zeros((3,))}, []),
... TensorDict({"a": torch.zeros((4,))}, []),
... ]
>>> padded_td = pad_sequence(list_td)
>>> print(padded_td)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
"""
if not list_of_tensordicts:
raise RuntimeError("list_of_tensordicts cannot be empty")
# check that all tensordict match
if return_mask:
list_of_tensordicts = [
td.clone(False).set("mask", torch.ones(td.shape, dtype=torch.bool))
for td in list_of_tensordicts
]
keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True)
shape = max(len(td) for td in list_of_tensordicts)
if shape == 0:
shape = [
len(list_of_tensordicts),
]
elif batch_first:
shape = [len(list_of_tensordicts), shape]
else:
shape = [shape, len(list_of_tensordicts)]
if out is None:
out = TensorDict(
{}, batch_size=torch.Size(shape), device=device, _run_checks=False
)
for key in keys:
try:
out.set(
key,
torch.nn.utils.rnn.pad_sequence(
[td.get(key) for td in list_of_tensordicts],
batch_first=batch_first,
padding_value=padding_value,
),
)
except Exception as err:
raise RuntimeError(f"pad_sequence failed for key {key}") from err
return out
else:
for key in keys:
out.set_(
key,
torch.nn.utils.rnn.pad_sequence(
[td.get(key) for td in list_of_tensordicts],
batch_first=batch_first,
padding_value=padding_value,
),
)
return out
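A sketch of the return_mask path above, padding variable-length tensordicts.

import torch
from torch.dict import TensorDict, pad_sequence

tds = [
    TensorDict({"tokens": torch.arange(3)}, batch_size=[3]),
    TensorDict({"tokens": torch.arange(5)}, batch_size=[5]),
]
padded = pad_sequence(tds, batch_first=True, return_mask=True)
print(padded["tokens"].shape)  # torch.Size([2, 5])
print(padded["mask"].sum(-1))  # tensor([3, 5]): number of valid steps per element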
def merge_tensordicts(*tensordicts: T) -> T:
"""Merges tensordicts together."""
if len(tensordicts) < 2:
raise RuntimeError(
f"at least 2 tensordicts must be provided, got" f" {len(tensordicts)}"
)
d = tensordicts[0].to_dict()
batch_size = tensordicts[0].batch_size
for td in tensordicts[1:]:
d.update(td.to_dict())
if td.batch_dims < len(batch_size):
batch_size = td.batch_size
return TensorDict(d, batch_size, device=td.device, _run_checks=False)
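Sketch: later arguments win on key clashes and the smallest batch size is kept, per the loop above.

import torch
from torch.dict import TensorDict, merge_tensordicts

td_a = TensorDict({"x": torch.zeros(4, 2), "y": torch.zeros(4, 2)}, batch_size=[4, 2])
td_b = TensorDict({"y": torch.ones(4, 2)}, batch_size=[4])
merged = merge_tensordicts(td_a, td_b)
print(merged.batch_size)   # torch.Size([4])
print(merged["y"].mean())  # tensor(1.): td_b's "y" overwrote td_a's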
def dense_stack_tds(
td_list: Sequence[TensorDictBase] | LazyStackedTensorDict,
dim: int = None,
) -> T:
"""Densely stack a list of :class:`~tensordict.TensorDictBase` objects (or a :class:`~tensordict.LazyStackedTensorDict`) given that they have the same structure.
This function is called with a list of :class:`~tensordict.TensorDictBase` (either passed directly or obtained from
a :class:`~tensordict.LazyStackedTensorDict`).
Instead of calling ``torch.stack(td_list)``, which would return a :class:`~tensordict.LazyStackedTensorDict`,
this function expands the first element of the input list and stacks the input list onto that element.
This works only when all the elements of the input list have the same structure.
The :class:`~tensordict.TensorDictBase` returned will have the same type of the elements of the input list.
This function is useful when some of the :class:`~tensordict.TensorDictBase` objects that need to be stacked
are :class:`~tensordict.LazyStackedTensorDict` or have :class:`~tensordict.LazyStackedTensorDict`
among entries (or nested entries).
In those cases, calling ``torch.stack(td_list).to_tensordict()`` is infeasible.
Thus, this function provides an alternative for densely stacking the list provided.
Args:
td_list (List of TensorDictBase or LazyStackedTensorDict): the tds to stack.
dim (int, optional): the dimension to stack them.
If td_list is a LazyStackedTensorDict, it will be retrieved automatically.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict import dense_stack_tds
>>> from tensordict.tensordict import assert_allclose_td
>>> td0 = TensorDict({"a": torch.zeros(3)},[])
>>> td1 = TensorDict({"a": torch.zeros(4), "b": torch.zeros(2)},[])
>>> td_lazy = torch.stack([td0, td1], dim=0)
>>> td_container = TensorDict({"lazy": td_lazy}, [])
>>> td_container_clone = td_container.clone()
>>> td_stack = torch.stack([td_container, td_container_clone], dim=0)
>>> td_stack
LazyStackedTensorDict(
fields={
lazy: LazyStackedTensorDict(
fields={
a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2, 2]),
device=None,
is_shared=False,
stack_dim=0)},
exclusive_fields={
},
batch_size=torch.Size([2]),
device=None,
is_shared=False,
stack_dim=0)
>>> td_stack = dense_stack_tds(td_stack) # Automatically use the LazyStackedTensorDict stack_dim
TensorDict(
fields={
lazy: LazyStackedTensorDict(
fields={
a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
exclusive_fields={
1 ->
b: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 2]),
device=None,
is_shared=False,
stack_dim=1)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
# Note that
# (1) td_stack is now a TensorDict
# (2) this has pushed the stack_dim of "lazy" (0 -> 1)
# (3) this has revealed the exclusive keys.
>>> assert_allclose_td(td_stack, dense_stack_tds([td_container, td_container_clone], dim=0))
# This shows it is the same to pass a list or a LazyStackedTensorDict
"""
if isinstance(td_list, LazyStackedTensorDict):
dim = td_list.stack_dim
td_list = td_list.tensordicts
elif dim is None:
raise ValueError(
"If a list of tensordicts is provided, stack_dim must not be None"
)
shape = list(td_list[0].shape)
shape.insert(dim, len(td_list))
result = td_list[0].unsqueeze(dim)
result = result.expand(shape)
result = result.clone()
return LazyStackedTensorDict.maybe_dense_stack(td_list, dim=dim, out=result)
def make_tensordict(
input_dict: dict[str, CompatibleType] | None = None,
batch_size: Sequence[int] | torch.Size | int | None = None,
device: DeviceType | None = None,
**kwargs: CompatibleType, # source
) -> TensorDict:
"""Returns a TensorDict created from the keyword arguments or an input dictionary.
If ``batch_size`` is not specified, returns the maximum batch size possible.
This function works on nested dictionaries too, or can be used to determine the
batch-size of a nested tensordict.
Args:
input_dict (dictionary, optional): a dictionary to use as a data source
(nested keys compatible).
**kwargs (TensorDict or torch.Tensor): keyword arguments as data source
(incompatible with nested keys).
batch_size (iterable of int, optional): a batch size for the tensordict.
device (torch.device or compatible type, optional): a device for the TensorDict.
Examples:
>>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)}
>>> print(make_tensordict(input_dict))
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
>>> # alternatively
>>> td = make_tensordict(**input_dict)
>>> # nested dict: the nested TensorDict can have a different batch-size
>>> # as long as its leading dims match.
>>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}
>>> print(make_tensordict(input_dict))
TensorDict(
fields={
a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
>>> # we can also use this to work out the batch size of a tensordict
>>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, [])
>>> print(make_tensordict(input_td))
TensorDict(
fields={
a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
"""
if input_dict is not None:
kwargs.update(input_dict)
return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device)

torch/dict/params.py (Normal file, 1174 changed lines): file diff suppressed because it is too large

torch/dict/tensorclass.py (Normal file, 1361 changed lines): file diff suppressed because it is too large

torch/dict/tensordict.py (Normal file, 2813 changed lines): file diff suppressed because it is too large

torch/dict/utils.py (Normal file, 1526 changed lines): file diff suppressed because it is too large

View File

@@ -89,57 +89,79 @@ def _untie_named_tensors_map(
@contextlib.contextmanager
def _reparametrize_module(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
parameters_and_buffers: Union[Dict[str, Tensor], "TensorDictBase"],
*,
tie_weights: bool = False,
strict: bool = False,
) -> Iterator[None]:
if tie_weights:
untied_parameters_and_buffers = _untie_named_tensors_map(
module, parameters_and_buffers
)
from torch.dict import TensorDictBase
if isinstance(parameters_and_buffers, TensorDictBase):
if strict:
raise NotImplementedError
if not tie_weights:
raise NotImplementedError
orig_parameters_and_buffers = parameters_and_buffers.empty()
try:
orig_parameters_and_buffers = parameters_and_buffers.to_module(
module, return_swap=True
)
yield
finally:
# tensordict is locked by default in this case, so we unlock it as we can't tell if an inplace update
# can be done (most likely not)
orig_parameters_and_buffers.to_module(
module, return_swap=True, swap_dest=parameters_and_buffers
)
else:
untied_parameters_and_buffers = parameters_and_buffers
accessor = NamedMemberAccessor(module)
if strict:
missing_keys, unexpected_keys = accessor.check_keys(
untied_parameters_and_buffers
)
error_msgs = []
if len(unexpected_keys) > 0:
error_msgs.append(
f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}."
if tie_weights:
untied_parameters_and_buffers = _untie_named_tensors_map(
module, parameters_and_buffers
)
if len(missing_keys) > 0:
error_msgs.append(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
if len(error_msgs) > 0:
raise RuntimeError(
"Error(s) in reparametrizing for {}:\n\t{}".format(
module._get_name(), "\n\t".join(error_msgs)
else:
untied_parameters_and_buffers = parameters_and_buffers
accessor = NamedMemberAccessor(module)
if strict:
missing_keys, unexpected_keys = accessor.check_keys(
untied_parameters_and_buffers
)
error_msgs = []
if len(unexpected_keys) > 0:
error_msgs.append(
f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}."
)
if len(missing_keys) > 0:
error_msgs.append(
f"Missing key(s): {', '.join(map(repr, missing_keys))}."
)
if len(error_msgs) > 0:
raise RuntimeError(
"Error(s) in reparametrizing for {}:\n\t{}".format(
module._get_name(), "\n\t".join(error_msgs)
)
)
)
orig_parameters_and_buffers: Dict[str, Tensor] = {}
try:
orig_parameters_and_buffers, _ = accessor.swap_tensors_dict(
untied_parameters_and_buffers, allow_missing=True
)
yield
finally:
new_parameters_and_buffers, _ = accessor.swap_tensors_dict(
orig_parameters_and_buffers, allow_missing=True
)
# Sometimes the module is not completely stateless and has some in-place modifications on
# the _parameters and _buffers dictionaries.
# Write the changed parameters and buffers back to the original dict.
parameters_and_buffers.update(
{
k: new_parameters_and_buffers[k]
for k in parameters_and_buffers
if k in new_parameters_and_buffers
}
)
orig_parameters_and_buffers: Dict[str, Tensor] = {}
try:
orig_parameters_and_buffers, _ = accessor.swap_tensors_dict(
untied_parameters_and_buffers, allow_missing=True
)
yield
finally:
new_parameters_and_buffers, _ = accessor.swap_tensors_dict(
orig_parameters_and_buffers, allow_missing=True
)
# Sometimes the module is not completely stateless and has some in-place modifications on
# the _parameters and _buffers dictionaries.
# Write the changed parameters and buffers back to the original dict.
parameters_and_buffers.update(
{
k: new_parameters_and_buffers[k]
for k in parameters_and_buffers
if k in new_parameters_and_buffers
}
)
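For reference, a sketch of the TensorDictBase branch above in isolation: to_module swaps the supplied tensors into the module, and the swap is reverted afterwards. TensorDict.from_module and .apply are assumed from the tensordict API; the to_module keyword arguments are the ones used in this hunk.

import torch
import torch.nn as nn
from torch.dict import TensorDict

module = nn.Linear(2, 2)
zeros = TensorDict.from_module(module).apply(lambda t: torch.zeros_like(t))
original = zeros.to_module(module, return_swap=True)            # module now runs with zeroed weights
assert module.weight.abs().sum() == 0
original.to_module(module, return_swap=True, swap_dest=zeros)   # restore the original parameters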
def functional_call(
@@ -228,7 +250,7 @@ def functional_call(
def _functional_call(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
parameters_and_buffers: Union[Dict[str, Tensor], "TensorDictBase"],
args: Union[Any, Tuple],
kwargs: Optional[Dict[str, Any]] = None,
*,