Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49734

RFC: https://github.com/pytorch/rfcs/pull/11

This PR adds the basic logic to handle forward grad as dual Tensors.
It contains the following:
- A mechanism to save dual state on a Tensor and clear it when the dual level ends
- C++ and Python user-facing APIs
- An updated view system that is able to track both forward and backward views

The current PR has the following limitations:
- Extensive tests are in the next PR in the stack, as formulas are needed to write full tests.
- Only the manual formulas have been audited, and no other formula is actually implemented here (they are in the next PR in the stack).
- Only level 0 is allowed for now. It was discussed and agreed that more levels are not needed for the first version of this PR.
- We can save one ViewInfo creation when both the forward and backward views have the same base. This can be done by adding a boolean flag to the DifferentiableViewMeta and extra logic in the `as_view` method. This is left out to keep this PR concise.
- We can skip tracking forward views if the base has a forward grad. This can be done by adding extra logic in the `as_view` method. This is left out to keep this PR concise.
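As a rough, torch-free illustration of the dual-state mechanism described above, the sketch below models scalar dual numbers: each value carries a primal and an optional tangent, arithmetic propagates tangents with the sum and product rules, and leaving the single dual level clears every tangent created inside it. `Dual`, `make_dual`, `unpack_dual` and `dual_level` are hypothetical stand-ins that only echo the names in `torch.autograd.forward_ad`; this is a sketch under those assumptions and does not reflect how PyTorch stores forward grads internally.

```python
from contextlib import contextmanager
from typing import Iterator, List, Optional, Tuple, Union

# Hypothetical, torch-free sketch of forward-mode AD with scalar dual numbers.
# None of these names are PyTorch APIs.

_level_open = False          # only one dual level may be open at a time
_tracked: List["Dual"] = []  # duals created with a tangent inside the level


class Dual:
    def __init__(self, primal: float, tangent: Optional[float] = None) -> None:
        self.primal = float(primal)
        self.tangent = tangent
        if tangent is not None:
            _tracked.append(self)  # remember it so the level can clear it

    def _t(self) -> float:
        # A missing tangent acts as zero inside the derivative formulas.
        return 0.0 if self.tangent is None else self.tangent

    def __add__(self, other: Union["Dual", float]) -> "Dual":
        o = _lift(other)
        return _result(self.primal + o.primal, self._t() + o._t(), self, o)

    def __mul__(self, other: Union["Dual", float]) -> "Dual":
        o = _lift(other)
        # Product rule: d(a * b) = da * b + a * db
        return _result(self.primal * o.primal,
                       self._t() * o.primal + self.primal * o._t(), self, o)

    __radd__ = __add__
    __rmul__ = __mul__


def _lift(x: Union[Dual, float]) -> Dual:
    # Plain numbers become constants with no tangent.
    return x if isinstance(x, Dual) else Dual(x)


def _result(primal: float, tangent: float, *inputs: Dual) -> Dual:
    # Attach a tangent only if some input actually carried one.
    if any(i.tangent is not None for i in inputs):
        return Dual(primal, tangent)
    return Dual(primal)


@contextmanager
def dual_level() -> Iterator[int]:
    global _level_open
    if _level_open:
        raise RuntimeError("nested dual levels are not supported in this sketch")
    _level_open = True
    try:
        yield 0  # the single level, echoing the "level 0 only" limitation
    finally:
        for d in _tracked:
            d.tangent = None  # forward grads do not outlive their level
        _tracked.clear()
        _level_open = False


def make_dual(primal: float, tangent: float) -> Dual:
    if not _level_open:
        raise RuntimeError("make_dual must be called inside dual_level()")
    return Dual(primal, tangent)


def unpack_dual(x: Dual) -> Tuple[float, Optional[float]]:
    return x.primal, x.tangent


with dual_level():
    x = make_dual(2.0, 1.0)  # seed dx/dx = 1
    y = x * x + 3.0 * x      # y = x**2 + 3x, so dy/dx = 2x + 3
    print(unpack_dual(y))    # (10.0, 7.0)
print(unpack_dual(y))        # (10.0, None): tangent cleared on exit
```

Because `make_dual` refuses to run outside `dual_level()` and nesting raises, the sketch also mirrors the "only level 0" restriction; the design choice of clearing tangents on exit is what keeps stale forward grads from leaking into later computations.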
Reading guide:
- Updated view handling in [gen_variable_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-f6553cec68caeaea36f6c8b14ff76a6d39dfd774e0ea9ef2f76e8d81fd9af5df), [VariableTypeUtils.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-ec71cfa45954dece1236c661d170e6341879c5be637f4abf52e826d61b40695a), [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285) (skip code below "[Forward Grad View]" for now), [variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-1604bcd0e4350ed99ec45e437cee7ac9ebe337392c9ea16a236247aeeb35b02bR266-R542) and [custom_function.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-dd85f452082b5bb6612bbc12adb496f8827defa228509f7b493de1d517522d5d). This introduces the new ViewInfo to hold view information shared between forward and backward. It also updates the differentiable view meta to use it, and updates the `as_view` function to handle both forward and backward views.
- New forward grad class that handles storing gradients and tracking at each level: [forward_grad.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c6c5b9ab2d7e5dde4102495faa1b6bbbfc23aa3e47deb7359c0bfe1eb004c0cb), [forward_grad.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-de2ab54ade7312701850d71a119a4f4ee4b9fc5a9c42a467cdd4e73c033531dd) and [build_variables.bzl](https://github.com/pytorch/pytorch/pull/49097/files#diff-dfdfa2efb17beddfd9094524f95351fd197db6c8857e96b436fb599870359325). EDIT: These files also contain the new flag to globally disable forward AD, which lets us limit performance issues while this is in development.
- Lowest level API and binding between Tensor and AutogradMeta in [TensorBody.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-7554853205392fa743357bf845ecc350a974ec049383248c12daaf2f4de04911), [TensorImpl.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-052bd9150ef8e09289ddf644b5a6830ede49207201cd41728f6d7cc6d9cead94), [TensorImpl.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-a15aae4cf23da44970db7cece62ff981265575c798c62f7b52d87c8809dfe2e1) and the rest of [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285R557-R677)
- API to access the forward primal, which needs to be a differentiable function (and so lives in native_functions.yaml): [native_functions.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991), [NamedRegistrations.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-69bd3bea510c9b64e1633fa18c3ea63d4b8348dbad3a78ad9de844ab3e43dc1d), [VariableMethodsStub.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-23f5fcb737a2b289811fe0f4b65aef775e7c824b2e629ecd343df51405cd434f), [derivatives.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_python_functions.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_trace_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-54e0b976027bf8debefb959ff360b89ae93466970c843365b1b3a03806d868ce), [TraceTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-f34636741ad4a23d018e0c289bc750c3bad887b45660e1d6eaf440d234a78fbf) and [part of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R198-R243)
- C++ API: [autograd.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-349028fbe8291a965a7a263c323b208fe071c35c66179ee997ef84fa81aa4b1e), [autograd.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-a3fe908d67dfec16a1fcde300de68b0701bf68b88db7451f29f2bee255cf30c9)
- Python binding: [init.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-c58a67c85191c22c9b3bb439117d8053edfd9dea839fa010cf967d404c3c630d)
- Python API: [forward_ad.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a4efad4ba18fffdfb264c21e5475997a24a743089a899f8ec1a5ff962c6738d9), [autograd/__init__.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-743abcafd32ad0e69f39ac5a91df4197b7e1921c135cacee7ef6dc829a8a7af8)
- C++ and Python printing: [Formatting.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-881dba501e71662e2e4818b4b016f739b344c8aed2f5edc6b871eda47a2aced0), [_tensor_str.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a7911f8d5e73adbff914d99fd7818ace2a7030b6a3748abe06ec6fc6e3df9cc3)
- Utility for formulas and updated manual functions to respect the new view system as well as forward grad: [FunctionsManual.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-6378bb6dc81a64dab676d61731341fa5d1088418f32a1473a33a0ccfc2357dc1), [FunctionsManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-4adbd88239afcd60e8198aab65d4f5e43b62314e34b80551e997a1ea503adea5), [rest of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R264-R433)
- Ensure SavedVariable saves forward grad properly: [saved_variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c1b8039d776241abe177d5aa99b79dd9489a9b3e529da8ab24c2e386c1238ae2), [saved_variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-cc9fba479b5beae06b2eea2e390d17796e0341c5b037a20b5bcaccbb0c341030)

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D25678797

Pulled By: albanD

fbshipit-source-id: 3d58550c11b5f58b9b73fd30596d042b857fb9dd
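On the printing side, `_str_intern` in the `_tensor_str.py` listing below calls `torch.autograd.forward_ad.unpack_dual`, prints the primal's data, and appends `tangent=...` to the same suffix list used for `grad_fn=` or `requires_grad=True`. The plain-Python sketch below restates the wrapping rule `_add_suffixes` applies to that list, so the effect of the new suffix on line breaks can be tried without torch. `format_repr` and its `linewidth` parameter (standing in for `PRINT_OPTS.linewidth`) are hypothetical names; the length bookkeeping is copied from `_add_suffixes` as written.

```python
from typing import List


def format_repr(data_str: str, suffixes: List[str], indent: int,
                linewidth: int = 80, force_newline: bool = False) -> str:
    # Same bookkeeping as _add_suffixes: estimate how long the current last line is.
    parts = [data_str]
    last_line_len = len(data_str) - data_str.rfind('\n') + 1
    for suffix in suffixes:
        if force_newline or last_line_len + len(suffix) + 2 > linewidth:
            # Too long (or forced): start a new line aligned under the data.
            parts.append(',\n' + ' ' * indent + suffix)
            last_line_len = indent + len(suffix)
            force_newline = False
        else:
            parts.append(', ' + suffix)
            last_line_len += len(suffix) + 2
    parts.append(')')
    return ''.join(parts)


prefix = 'tensor('
primal_str = prefix + '[1., 2., 3.]'
tangent_suffix = 'tangent=tensor([0., 1., 0.])'
print(format_repr(primal_str, [tangent_suffix], indent=len(prefix)))
print(format_repr(primal_str, ['requires_grad=True', tangent_suffix],
                  indent=len(prefix), linewidth=30))
```

With the default width everything fits on one line (`tensor([1., 2., 3.], tangent=tensor([0., 1., 0.]))`); with `linewidth=30` each suffix moves to its own line, indented by `len('tensor(')` so it lines up under the data.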
384 lines
16 KiB
Python
import math
import torch
from torch._six import inf
from typing import Optional


class __PrinterOptions(object):
    precision: int = 4
    threshold: float = 1000
    edgeitems: int = 3
    linewidth: int = 80
    sci_mode: Optional[bool] = None


PRINT_OPTS = __PrinterOptions()


# We could use **kwargs, but this will give better docs
def set_printoptions(
        precision=None,
        threshold=None,
        edgeitems=None,
        linewidth=None,
        profile=None,
        sci_mode=None
):
    r"""Set options for printing. Items shamelessly taken from NumPy

    Args:
        precision: Number of digits of precision for floating point output
            (default = 4).
        threshold: Total number of array elements which trigger summarization
            rather than full `repr` (default = 1000).
        edgeitems: Number of array items in summary at beginning and end of
            each dimension (default = 3).
        linewidth: The number of characters per line for the purpose of
            inserting line breaks (default = 80). Thresholded matrices will
            ignore this parameter.
        profile: Sane defaults for pretty printing. Can override with any of
            the above options. (any one of `default`, `short`, `full`)
        sci_mode: Enable (True) or disable (False) scientific notation. If
            None (default) is specified, the value is defined by
            `torch._tensor_str._Formatter`. This value is automatically chosen
            by the framework.
    """
    if profile is not None:
        if profile == "default":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80
        elif profile == "short":
            PRINT_OPTS.precision = 2
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 2
            PRINT_OPTS.linewidth = 80
        elif profile == "full":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = inf
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80

    if precision is not None:
        PRINT_OPTS.precision = precision
    if threshold is not None:
        PRINT_OPTS.threshold = threshold
    if edgeitems is not None:
        PRINT_OPTS.edgeitems = edgeitems
    if linewidth is not None:
        PRINT_OPTS.linewidth = linewidth
    PRINT_OPTS.sci_mode = sci_mode


class _Formatter(object):
    def __init__(self, tensor):
        self.floating_dtype = tensor.dtype.is_floating_point
        self.int_mode = True
        self.sci_mode = False
        self.max_width = 1

        with torch.no_grad():
            tensor_view = tensor.reshape(-1)

        if not self.floating_dtype:
            for value in tensor_view:
                value_str = '{}'.format(value)
                self.max_width = max(self.max_width, len(value_str))

        else:
            nonzero_finite_vals = torch.masked_select(tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0))

            if nonzero_finite_vals.numel() == 0:
                # no valid number, do nothing
                return

            # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
            nonzero_finite_abs = nonzero_finite_vals.abs().double()
            nonzero_finite_min = nonzero_finite_abs.min().double()
            nonzero_finite_max = nonzero_finite_abs.max().double()

            for value in nonzero_finite_vals:
                if value != torch.ceil(value):
                    self.int_mode = False
                    break

            if self.int_mode:
                # in int_mode for floats, all numbers are integers, and we append a decimal to finite values
                # to indicate that the tensor is of floating type. add 1 to the len to account for this.
                if nonzero_finite_max / nonzero_finite_min > 1000. or nonzero_finite_max > 1.e8:
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = ('{{:.{}e}}').format(PRINT_OPTS.precision).format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = ('{:.0f}').format(value)
                        self.max_width = max(self.max_width, len(value_str) + 1)
            else:
                # Check if scientific representation should be used.
                if nonzero_finite_max / nonzero_finite_min > 1000.\
                        or nonzero_finite_max > 1.e8\
                        or nonzero_finite_min < 1.e-4:
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = ('{{:.{}e}}').format(PRINT_OPTS.precision).format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = ('{{:.{}f}}').format(PRINT_OPTS.precision).format(value)
                        self.max_width = max(self.max_width, len(value_str))

        if PRINT_OPTS.sci_mode is not None:
            self.sci_mode = PRINT_OPTS.sci_mode

    def width(self):
        return self.max_width

    def format(self, value):
        if self.floating_dtype:
            if self.sci_mode:
                ret = ('{{:{}.{}e}}').format(self.max_width, PRINT_OPTS.precision).format(value)
            elif self.int_mode:
                ret = '{:.0f}'.format(value)
                if not (math.isinf(value) or math.isnan(value)):
                    ret += '.'
            else:
                ret = ('{{:.{}f}}').format(PRINT_OPTS.precision).format(value)
        else:
            ret = '{}'.format(value)
        return (self.max_width - len(ret)) * ' ' + ret


def _scalar_str(self, formatter1, formatter2=None):
    if formatter2 is not None:
        real_str = _scalar_str(self.real, formatter1)
        imag_str = _scalar_str(self.imag, formatter2) + "j"
        if self.imag < 0:
            return real_str + imag_str.lstrip()
        else:
            return real_str + "+" + imag_str.lstrip()
    else:
        return formatter1.format(self.item())

def _vector_str(self, indent, summarize, formatter1, formatter2=None):
    # length includes spaces and comma between elements
    element_length = formatter1.width() + 2
    if formatter2 is not None:
        # width for imag_formatter + an extra j for complex
        element_length += formatter2.width() + 1

    elements_per_line = max(1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length))))
    char_per_line = element_length * elements_per_line

    def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
        if formatter2 is not None:
            real_str = formatter1.format(val.real)
            imag_str = formatter2.format(val.imag) + "j"
            if val.imag < 0:
                return real_str + imag_str.lstrip()
            else:
                return real_str + "+" + imag_str.lstrip()
        else:
            return formatter1.format(val)

    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        data = ([_val_formatter(val) for val in self[:PRINT_OPTS.edgeitems].tolist()] +
                [' ...'] +
                [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems:].tolist()])
    else:
        data = [_val_formatter(val) for val in self.tolist()]

    data_lines = [data[i:i + elements_per_line] for i in range(0, len(data), elements_per_line)]
    lines = [', '.join(line) for line in data_lines]
    return '[' + (',' + '\n' + ' ' * (indent + 1)).join(lines) + ']'

# formatter2 is only used for printing complex tensors.
# For complex tensors, formatter1 and formatter2 are the formatters for tensor.real
# and tensor.imag respectively
def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None):
    dim = self.dim()

    if dim == 0:
        return _scalar_str(self, formatter1, formatter2)

    if dim == 1:
        return _vector_str(self, indent, summarize, formatter1, formatter2)

    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        slices = ([_tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1, formatter2)
                   for i in range(0, PRINT_OPTS.edgeitems)] +
                  ['...'] +
                  [_tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1, formatter2)
                   for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))])
    else:
        slices = [_tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1, formatter2)
                  for i in range(0, self.size(0))]

    tensor_str = (',' + '\n' * (dim - 1) + ' ' * (indent + 1)).join(slices)
    return '[' + tensor_str + ']'

def _tensor_str(self, indent):
    if self.numel() == 0:
        return '[]'

    if self.has_names():
        # There are two main codepaths (possibly more) that tensor printing goes through:
        # - tensor data can fit comfortably on screen
        # - tensor data needs to be summarized
        # Some of the codepaths don't fully support named tensors, so we send in
        # an unnamed tensor to the formatting code as a workaround.
        self = self.rename(None)

    summarize = self.numel() > PRINT_OPTS.threshold
    if self.dtype is torch.float16 or self.dtype is torch.bfloat16:
        self = self.float()

    if self.dtype.is_complex:
        real_formatter = _Formatter(get_summarized_data(self.real) if summarize else self.real)
        imag_formatter = _Formatter(get_summarized_data(self.imag) if summarize else self.imag)
        return _tensor_str_with_formatter(self, indent, summarize, real_formatter, imag_formatter)
    else:
        formatter = _Formatter(get_summarized_data(self) if summarize else self)
        return _tensor_str_with_formatter(self, indent, summarize, formatter)

def _add_suffixes(tensor_str, suffixes, indent, force_newline):
    tensor_strs = [tensor_str]
    last_line_len = len(tensor_str) - tensor_str.rfind('\n') + 1
    for suffix in suffixes:
        suffix_len = len(suffix)
        if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth:
            tensor_strs.append(',\n' + ' ' * indent + suffix)
            last_line_len = indent + suffix_len
            force_newline = False
        else:
            tensor_strs.append(', ' + suffix)
            last_line_len += suffix_len + 2
    tensor_strs.append(')')
    return ''.join(tensor_strs)


def get_summarized_data(self):
    dim = self.dim()
    if dim == 0:
        return self
    if dim == 1:
        if self.size(0) > 2 * PRINT_OPTS.edgeitems:
            return torch.cat((self[:PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems:]))
        else:
            return self
    if self.size(0) > 2 * PRINT_OPTS.edgeitems:
        start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)]
        end = ([self[i]
               for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))])
        return torch.stack([get_summarized_data(x) for x in (start + end)])
    else:
        return torch.stack([get_summarized_data(x) for x in self])

def _str_intern(inp):
    prefix = 'tensor('
    indent = len(prefix)
    suffixes = []

    # This is used to extract the primal value and thus disable the forward AD
    # within this function.
    # TODO(albanD) This needs to be updated when more than one level is supported
    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)

    # Note [Print tensor device]:
    # A general logic here is we only print device when it doesn't match
    # the device specified in default tensor type.
    # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus
    # torch._C._get_default_device() only returns either cpu or cuda.
    # In other cases, we don't have a way to set them as default yet,
    # and we should always print out device for them.
    if self.device.type != torch._C._get_default_device()\
            or (self.device.type == 'cuda' and torch.cuda.current_device() != self.device.index):
        suffixes.append('device=\'' + str(self.device) + '\'')

    # TODO: add an API to map real -> complex dtypes
    _default_complex_dtype = torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
    has_default_dtype = self.dtype in (torch.get_default_dtype(), _default_complex_dtype, torch.int64, torch.bool)
    if self.is_sparse:
        suffixes.append('size=' + str(tuple(self.shape)))
        suffixes.append('nnz=' + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append('dtype=' + str(self.dtype))
        indices_prefix = 'indices=tensor('
        indices = self._indices().detach()
        indices_str = _tensor_str(indices, indent + len(indices_prefix))
        if indices.numel() == 0:
            indices_str += ', size=' + str(tuple(indices.shape))
        values_prefix = 'values=tensor('
        values = self._values().detach()
        values_str = _tensor_str(values, indent + len(values_prefix))
        if values.numel() == 0:
            values_str += ', size=' + str(tuple(values.shape))
        tensor_str = indices_prefix + indices_str + '),\n' + ' ' * indent + values_prefix + values_str + ')'
    elif self.is_quantized:
        suffixes.append('size=' + str(tuple(self.shape)))
        if not has_default_dtype:
            suffixes.append('dtype=' + str(self.dtype))
        suffixes.append('quantization_scheme=' + str(self.qscheme()))
        if self.qscheme() == torch.per_tensor_affine or self.qscheme() == torch.per_tensor_symmetric:
            suffixes.append('scale=' + str(self.q_scale()))
            suffixes.append('zero_point=' + str(self.q_zero_point()))
        elif self.qscheme() == torch.per_channel_affine or self.qscheme() == torch.per_channel_symmetric \
                or self.qscheme() == torch.per_channel_affine_float_qparams:
            suffixes.append('scale=' + str(self.q_per_channel_scales()))
            suffixes.append('zero_point=' + str(self.q_per_channel_zero_points()))
            suffixes.append('axis=' + str(self.q_per_channel_axis()))
        tensor_str = _tensor_str(self.dequantize(), indent)
    else:
        if self.is_meta:
            suffixes.append('size=' + str(tuple(self.shape)))
            if self.dtype != torch.get_default_dtype():
                suffixes.append('dtype=' + str(self.dtype))
            # TODO: This implies that ellipses is valid syntax for allocating
            # a meta tensor, which it could be, but it isn't right now
            tensor_str = '...'
        else:
            if self.numel() == 0 and not self.is_sparse:
                # Explicitly print the shape if it is not (0,), to match NumPy behavior
                if self.dim() != 1:
                    suffixes.append('size=' + str(tuple(self.shape)))

                # In an empty tensor, there are no elements to infer if the dtype
                # should be int64, so it must be shown explicitly.
                if self.dtype != torch.get_default_dtype():
                    suffixes.append('dtype=' + str(self.dtype))
                tensor_str = '[]'
            else:
                if not has_default_dtype:
                    suffixes.append('dtype=' + str(self.dtype))

                if self.layout != torch.strided:
                    tensor_str = _tensor_str(self.to_dense(), indent)
                else:
                    tensor_str = _tensor_str(self, indent)

    if self.layout != torch.strided:
        suffixes.append('layout=' + str(self.layout))

    # Use inp here to get the original grad_fn and not the one generated by the forward grad
    # unpacking.
    if inp.grad_fn is not None:
        name = type(inp.grad_fn).__name__
        if name == 'CppFunction':
            name = inp.grad_fn.name().rsplit('::', 1)[-1]
        suffixes.append('grad_fn=<{}>'.format(name))
    elif inp.requires_grad:
        suffixes.append('requires_grad=True')

    if self.has_names():
        suffixes.append('names={}'.format(self.names))

    if tangent is not None:
        suffixes.append('tangent={}'.format(tangent))

    return _add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse)

def _str(self):
    with torch.no_grad():
        return _str_intern(self)
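For floating-point tensors, `_Formatter.__init__` above chooses one of three layouts from the nonzero finite values: integer-valued floats printed with a trailing `.`, fixed-point with `PRINT_OPTS.precision` digits, or scientific notation when the max/min magnitude ratio exceeds 1000, the max exceeds 1e8, or (for non-integer data only) the min is below 1e-4. The sketch below restates those rules for a flat list of Python floats so they can be tried without torch. `choose_float_format` and `format_value` are hypothetical helpers; `precision=4` stands in for `PRINT_OPTS.precision`, and the `sci_mode` override from `set_printoptions` is deliberately left out.

```python
import math
from typing import List, Tuple


def choose_float_format(values: List[float], precision: int = 4) -> Tuple[bool, bool, int]:
    """Return (int_mode, sci_mode, max_width) following _Formatter's float rules."""
    int_mode, sci_mode, max_width = True, False, 1
    nonzero_finite = [v for v in values if math.isfinite(v) and v != 0]
    if not nonzero_finite:
        return int_mode, sci_mode, max_width  # nothing to measure

    magnitudes = [abs(v) for v in nonzero_finite]
    lo, hi = min(magnitudes), max(magnitudes)
    int_mode = all(v == math.ceil(v) for v in nonzero_finite)

    if int_mode:
        sci_mode = hi / lo > 1000. or hi > 1.e8
    else:
        sci_mode = hi / lo > 1000. or hi > 1.e8 or lo < 1.e-4

    if sci_mode:
        widths = [len('{:.{}e}'.format(v, precision)) for v in nonzero_finite]
    elif int_mode:
        # +1 reserves room for the trailing '.' that marks a float dtype
        widths = [len('{:.0f}'.format(v)) + 1 for v in nonzero_finite]
    else:
        widths = [len('{:.{}f}'.format(v, precision)) for v in nonzero_finite]
    return int_mode, sci_mode, max(max_width, *widths)


def format_value(v: float, int_mode: bool, sci_mode: bool, width: int, precision: int = 4) -> str:
    """Format one value the way _Formatter.format does, right-aligned to width."""
    if sci_mode:
        s = '{:{}.{}e}'.format(v, width, precision)
    elif int_mode:
        s = '{:.0f}'.format(v)
        if not (math.isinf(v) or math.isnan(v)):
            s += '.'
    else:
        s = '{:.{}f}'.format(v, precision)
    return s.rjust(width)


for vals in ([1.0, 2.0, 3.0], [0.5, 1.25], [1e-5, 1.0]):
    int_mode, sci_mode, width = choose_float_format(vals)
    print('[' + ', '.join(format_value(v, int_mode, sci_mode, width) for v in vals) + ']')
```

Run as-is, the loop prints `[1., 2., 3.]`, `[0.5000, 1.2500]` and `[1.0000e-05, 1.0000e+00]`. Computing one shared width up front is what lets every element of a printed tensor line up in columns, at the cost of a full pass over the (possibly summarized) data.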