## UX
### What is type promotion and why do some PyTorch operations support it while others don't?
An operation in PyTorch can be thought of as having one or more inputs, a computation, and an output. Type promotion is when an input is converted to a "higher" dtype for use in the operation's computation. This is often necessary to define operations sensibly. For example, adding tensors of different dtypes requires selecting a "computation type" to add them in, so when adding a float32 and a float64 tensor PyTorch computes the addition in float64. See https://pytorch.org/docs/master/tensor_attributes.html for more details on type promotion.
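
This rule can be observed directly. A minimal illustration, using `torch.result_type` to report the promoted dtype for a pair of inputs:

```python
import torch

a = torch.ones(3, dtype=torch.float32)
b = torch.ones(3, dtype=torch.float64)

# float64 is the "higher" dtype, so the addition is computed in float64
print(torch.result_type(a, b))  # torch.float64
print((a + b).dtype)            # torch.float64
```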
While type promotion can be convenient, it can also be confusing to users and a source of error. Therefore, we've decided to limit type promotion in PyTorch to the following classes of operations:
- unary pointwise operations where the codomain of the input cannot be represented in the input's dtype
- binary pointwise operations
- reductions where the codomain of the input cannot be represented in the input's dtype

Binary pointwise operations, like add, are relatively straightforward. An example of a unary pointwise operation with type promotion is sin: since sin(1) ≈ 0.8415, the codomain of the integers under sin is not the integers or, to put it more plainly, integer inputs produce non-integer values. If operations like sin didn't type promote their inputs they would simply be undefined on integer tensors. Similarly, reductions like mean can be defined on integer tensors by type promoting them to a floating point type, since the mean of a tensor of integers may not be an integer.
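
A small sketch of the unary pointwise case: sin promotes integer inputs to the default floating point dtype (typically float32):

```python
import torch

i = torch.tensor([0, 1, 2])  # int64 tensor

# sin's values can't be represented in an integer dtype, so the input
# is promoted to the default floating point dtype before computing
print(torch.sin(i))        # tensor([0.0000, 0.8415, 0.9093])
print(torch.sin(i).dtype)  # torch.float32
```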
Beyond these three classes type promotion becomes trickier, less intuitive, and more error prone. While some operations outside of these classes implement type promotion today, PyTorch's current plan is to add type promotion support only to unary pointwise, binary pointwise, and reduction operations. Operations outside these classes that already support type promotion will not be changed, to preserve backwards compatibility.
### How does out= work in PyTorch?
When a user passes one or more tensors to out= the contract is as follows:

- if an out tensor has no elements it will be resized to the shape, strides, and memory format of the output of the computation
- if an out tensor has a different shape than the result of the computation an error is thrown, OR a warning is thrown and the tensor is resized (this resizing behavior is deprecated)
- passing correctly shaped tensors to out= is numerically equivalent to performing the operation and "safe copying" its results into the out tensors
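
A brief sketch of the resizing part of the contract:

```python
import torch

a = torch.ones(2, 3)
b = torch.ones(2, 3)

# An out= tensor with no elements is resized to the result's shape
out = torch.empty(0)
torch.add(a, b, out=out)
print(out.shape)  # torch.Size([2, 3])
```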
A "safe copy" is different from PyTorch's regular copy. For operations that do not participate in type promotion, the device and dtype of the source and destination tensors must match. For operations that do participate in type promotion, the copy can be to a different dtype, but the destination of the copy cannot be a lower "type kind" than the source. PyTorch has four type kinds, ordered from lowest to highest: boolean, integer, float, and complex. So, for example, an operation like add (which participates in type promotion) will throw a runtime error if given float inputs but an integer out= tensor.
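
A minimal sketch of both directions of the "safe copy" rule:

```python
import torch

a = torch.ones(3, dtype=torch.float32)
b = torch.ones(3, dtype=torch.float64)

# Copying to a higher type kind is allowed: a float result can be
# safe copied into a complex out= tensor
out_complex = torch.empty(3, dtype=torch.complex128)
torch.add(a, b, out=out_complex)

# Copying to a lower type kind is not: a float result into an
# integer out= tensor raises a RuntimeError
out_int = torch.empty(3, dtype=torch.int64)
try:
    torch.add(a, b, out=out_int)
except RuntimeError as e:
    print(e)  # result type Double can't be cast to ... Long
```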
Note that while the numerics of out= are that the operation is performed and then its results are "safe copied," behind the scenes operations may reuse the storage of out= tensors and fuse the copy for efficiency. Many operations, like add, perform these optimizations. Also, while PyTorch's "out= contract" is specified above, many operations do not yet correctly implement it and need to be updated.
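
The storage reuse can be observed directly (a small sketch; whether a given operation fuses the copy is an implementation detail):

```python
import torch

a = torch.ones(3)
b = torch.ones(3)
out = torch.empty(3)
ptr = out.data_ptr()

# add writes into out's existing storage instead of allocating a
# separate result tensor and then copying it over
result = torch.add(a, b, out=out)
print(result is out)          # True: the out tensor itself is returned
print(out.data_ptr() == ptr)  # True: the storage was reused
```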
### How do in-place operations work in PyTorch?