[BE] fix typos in functorch/ and scripts/ (#156081)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156081
Approved by: https://github.com/albanD
ghstack dependencies: #156080
Xuehai Pan
2025-06-18 23:58:49 +08:00
committed by PyTorch MergeBot
parent 2ccfd14e23
commit e3507c3777
17 changed files with 30 additions and 29 deletions

View File

@@ -710,7 +710,7 @@ public:
auto t = Tensor::wrap(run_torch_function(A, delayed_->orig, delayed_->args, true));
tensor_ = t->tensor(A);
delayed_.reset();
-// don't force creation of batch tensor if it wasn't alreay provided.
+// don't force creation of batch tensor if it wasn't already provided.
batchtensor_ = t->batchtensor_;
AT_ASSERT(levels() == t->levels());
}
@@ -1739,7 +1739,7 @@ static mpy::object dot(Arena& A, TensorInfo lhs, TensorInfo rhs, Slice<DimEntry>
if (lr_dims.dims.size() != sum.size()) {
for (auto & d : sum) {
if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) {
-mpy::raise_error(DimensionBindError(), "summing over non-existant dimension %S", d.dim().ptr());
+mpy::raise_error(DimensionBindError(), "summing over non-existent dimension %S", d.dim().ptr());
}
}
}
@@ -2206,7 +2206,7 @@ mpy::object index(Arena& A, mpy::handle self, mpy::handle dims, mpy::handle indi
self_info.tensor = A.autorelease(rearranged->reshape(at::IntArrayRef(new_sizes.begin(), new_sizes.end())));
self_info.levels = reshape_levels; // note: we are using the first level in a flattened group to represent the group for the rest of the op
-// we need to be careful not to rely the dimensions size because it doesnt match the size of the whole group
+// we need to be careful not to rely the dimensions size because it doesn't match the size of the whole group
}
bool has_dimpacks = false;
for (auto idx : indices_list) {
@@ -2219,7 +2219,7 @@ mpy::object index(Arena& A, mpy::handle self, mpy::handle dims, mpy::handle indi
return invoke_getitem(A, info);
}
-// true -- the indices were flattend out of a tuple, list or sequence...
+// true -- the indices were flattened out of a tuple, list or sequence...
Slice<mpy::handle> slice_from_sequence(Arena& A, mpy::handle value) {
if (mpy::tuple_view::check(value)) {
@@ -2539,7 +2539,7 @@ IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice<mpy::handle>
}
} else if (Dim::check_exact(inp)) {
auto d = Dim::unchecked_wrap(inp);
-// dimesions used once are just binding operations
+// dimensions used once are just binding operations
if (1 == seen_dims_nuses[*seen_dims.index(d)]) {
flat_inputs[i] = no_slice;
result_levels.append(A, d);
@@ -2798,7 +2798,7 @@ PyObject* py_split(PyObject *_,
if (!dim.ptr()) {
dim = A.autorelease(mpy::from_int(0));
}
-mpy::raise_error(PyExc_TypeError, "tensor does not comtain dimension %R", dim.ptr());
+mpy::raise_error(PyExc_TypeError, "tensor does not contain dimension %R", dim.ptr());
}
Slice<int64_t> indices;

View File

@@ -6,7 +6,7 @@
#pragma once
// note: pytorch's python variable simple includes pybind which conflicts with minpybind
-// so this file just reproduces the minimial API needed to extract Tensors from python objects.
+// so this file just reproduces the minimal API needed to extract Tensors from python objects.
#include <torch/csrc/python_headers.h>
#include <ATen/core/Tensor.h>

View File

@@ -5,7 +5,7 @@ Named Tensors using First-class Dimensions in PyTorch
_An implementation of [named tensors](https://namedtensor.github.io) with the functionality of [einsum](http://einops.rocks]http://einops.rocks) , batching ([vmap](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap), [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html)), and tensor indexing by adding dimension objects to PyTorch_.
-The tensor input to a resnet might have the shape [8, 3, 224, 224] but informally we think of those dimensions as 'batch', 'channel', 'width', and 'height'. Eventhough 'width' and 'height' have the same _size_ we still think of them as separate dimensions, and if we have two _different_ images, we think of both as sharing the _same_ 'channel' dimension.
+The tensor input to a resnet might have the shape [8, 3, 224, 224] but informally we think of those dimensions as 'batch', 'channel', 'width', and 'height'. Even though 'width' and 'height' have the same _size_ we still think of them as separate dimensions, and if we have two _different_ images, we think of both as sharing the _same_ 'channel' dimension.
Named tensors gives these dimensions names. [PyTorch's current implementation](https://pytorch.org/docs/stable/named_tensor.html) uses strings to name dimensions. Instead, this library introduces a Python object, a `Dim`, to represent the concept. By expanding the semantics of tensors with dim objects, in addition to naming dimensions, we can get behavior equivalent to batching transforms (xmap, vmap), einops-style rearrangement, and loop-style tensor indexing.
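A minimal sketch of the Dim objects this README describes, assuming the functorch.dim API (dims, indexing by Dim, order) behaves as documented there; this is illustrative only and not part of the diff:

import torch
from functorch.dim import dims

batch, channel, width, height = dims(4)            # four first-class dimension objects
images = torch.randn(8, 3, 224, 224)
images_fc = images[batch, channel, width, height]  # bind each positional dim to a Dim object
avg_color = images_fc.mean((width, height))        # refer to dims by object, not by position
print(avg_color.order(batch, channel).shape)       # back to positional dims: torch.Size([8, 3])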
@@ -751,7 +751,7 @@ In this way, first-class dims are a way of adapting the nicer syntax of these ar
Performance Expectations
========================
-First-class dimensions are not a compiler. They provide syntax for existing PyTorch operations such as advanced indexing that is easier to read and write. For large sized tensors, the performance of any statements including them will be the same as using the already existing operations. An important exception is the pattern matching of products and summation, where performance will be improved by issuing to a matrix-multiply kernel. The C++ implementation of dimensions adds a small overhead of around 2us on top of PyTorch's normal overhead of 8us to each function that uses them. In the future, the implementation can encorporate more fusion optimization to further improve performance of this style of code.
+First-class dimensions are not a compiler. They provide syntax for existing PyTorch operations such as advanced indexing that is easier to read and write. For large sized tensors, the performance of any statements including them will be the same as using the already existing operations. An important exception is the pattern matching of products and summation, where performance will be improved by issuing to a matrix-multiply kernel. The C++ implementation of dimensions adds a small overhead of around 2us on top of PyTorch's normal overhead of 8us to each function that uses them. In the future, the implementation can incorporate more fusion optimization to further improve performance of this style of code.
## License

View File

@@ -58,7 +58,7 @@ TensorLike = (_Tensor, torch.Tensor)
class Dim(_C.Dim, _Tensor):
-# note that _C.Dim comes before tensor because we want the Dim API for things like size to take precendence.
+# note that _C.Dim comes before tensor because we want the Dim API for things like size to take precedence.
# Tensor defines format, but we want to print Dims with special formatting
__format__ = object.__format__

View File

@@ -507,7 +507,7 @@ def t__getitem__(self, input):
for i in reversed(dim_packs):
input[i : i + 1] = input[i]
-# currenty:
+# currently:
# input is flat, containing either Dim, or Tensor, or something valid for standard indexing
# self may have first-class dims as well.
@@ -515,7 +515,7 @@ def t__getitem__(self, input):
# drop the first class dims from self, they just become direct indices of their positions
# figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index.
-# these dimensions will appear and need to be bound at the first place tensor occures
+# these dimensions will appear and need to be bound at the first place tensor occurs
if isinstance(self, _Tensor):
ptensor_self, levels = self._tensor, list(self._levels)

View File

@@ -138,7 +138,7 @@ step6()
# Step 7: Now, the flaw with step 6 is that we were training on the same exact
# data. This can lead to all of the models in the ensemble overfitting in the
# same way. The solution that http://willwhitney.com/parallel-training-jax.html
-# applies is to randomly subset the data in a way that the models do not recieve
+# applies is to randomly subset the data in a way that the models do not receive
# exactly the same data in each training step!
# Because the goal of this doc is to show that we can use eager-mode vmap to
# achieve similar things as JAX, the rest of this is left as an exercise to the reader.
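One way to realize the random subsetting described in that comment is sketched below; the shapes and names (N, num_models, batch_size) are assumptions for illustration, not code from the example file:

import torch

N, num_models, batch_size = 1000, 5, 64
data = torch.randn(N, 28 * 28)
targets = torch.randint(0, 10, (N,))
idx = torch.randint(0, N, (num_models, batch_size))  # an independent random subset per model
minibatches = data[idx]                              # shape (num_models, batch_size, 784)
minibatch_targets = targets[idx]                     # shape (num_models, batch_size)
# each row of `minibatches` can then be fed to the vmap-ed per-model training step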

View File

@@ -1,4 +1,4 @@
-# This example was adapated from https://github.com/muhrin/milad
+# This example was adapted from https://github.com/muhrin/milad
# It is licensed under the GLPv3 license. You can find a copy of it
# here: https://www.gnu.org/licenses/gpl-3.0.en.html .

View File

@@ -100,7 +100,7 @@ ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
# vjp and vmap transforms.
# - jacfwd uses forward-mode AD. It is implemented as a composition of our
# jvp and vmap transforms.
-# jacfwd and jacrev can be subsituted for each other and have different
+# jacfwd and jacrev can be substituted for each other and have different
# performance characteristics.
#
# As a general rule of thumb, if you're computing the jacobian of an R^N -> R^M
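As background for that rule of thumb: jacfwd builds the Jacobian one jvp (column) per input, so it tends to win when a function has many more outputs than inputs, while jacrev builds it one vjp (row) per output, so it tends to win in the opposite case. A small illustrative sketch (the function f is made up, not from the example file):

import torch
from functorch import jacfwd, jacrev

def f(x):                                # toy R^3 -> R^12 function
    return torch.sin(x).repeat(4)

x = torch.randn(3)
J_fwd = jacfwd(f)(x)                     # forward-mode: one jvp per input
J_rev = jacrev(f)(x)                     # reverse-mode: one vjp per output
assert torch.allclose(J_fwd, J_rev)      # same (12, 3) Jacobian, different cost profiles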

View File

@@ -350,7 +350,7 @@
{
"cell_type": "markdown",
"source": [
"Furthemore, its pretty easy to flip the problem around and say we want to compute Jacobians of the parameters to our model (weight, bias) instead of the input."
"Furthermore, its pretty easy to flip the problem around and say we want to compute Jacobians of the parameters to our model (weight, bias) instead of the input."
],
"metadata": {
"id": "EQAB99EQflUJ"

View File

@@ -123,7 +123,7 @@
"predictions = model(data) # move the entire mini-batch through the model\n",
"\n",
"loss = loss_fn(predictions, targets)\n",
"loss.backward() # back propogate the 'average' gradient of this mini-batch"
"loss.backward() # back propagate the 'average' gradient of this mini-batch"
],
"metadata": {
"id": "WYjMx8QTUvRu"