mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
@ -7,7 +7,7 @@
|
||||
| [**Future Plans**](#future-plans)
|
||||
|
||||
**This library is currently under heavy development - if you have suggestions
|
||||
on the API or use-cases you'd like to be covered, please open an github issue
|
||||
on the API or use-cases you'd like to be covered, please open a GitHub issue
|
||||
or reach out. We'd love to hear about how you're using the library.**
|
||||
|
||||
`functorch` is [JAX-like](https://github.com/google/jax) composable function
|
||||
@ -161,7 +161,7 @@ result = vmap(model)(examples)
|
||||
|
||||
### grad
|
||||
|
||||
`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It compute
|
||||
`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It computes
|
||||
the gradients of the output of func w.r.t. to `inputs[0]`.
|
||||
|
||||
```py
|
||||
@ -192,7 +192,7 @@ def compute_loss(weights, example, target):
|
||||
weights = torch.randn(feature_size, requires_grad=True)
|
||||
examples = torch.randn(batch_size, feature_size)
|
||||
targets = torch.randn(batch_size)
|
||||
inputs = (weights,examples, targets)
|
||||
inputs = (weights, examples, targets)
|
||||
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
|
||||
```
|
||||
|
||||
|
@ -5,7 +5,7 @@ First off, what are batching rules and why do we need so many of them? Well, to
|
||||
### How does vmap work?
|
||||
Vmap is a function transform (pioneered by Jax) that allows one to batch functions. That is, given a function `f(x: [N]) -> [N]`, `vmap(f)` now transforms the signature to be `f(x: [B, N]) -> [B, N]`. That is - it adds a batch dimension to both the input and the output of the function.
|
||||
|
||||
This guide will gloss over all the cool things you can do this (there are many!), so let's focus on how we actually implement this.
|
||||
This guide will gloss over all the cool things you can do with this (there are many!), so let's focus on how we actually implement this.
|
||||
|
||||
One misconception is that this is some magic compiler voodoo, or that it is inherently some function transform. It is not - and there's another framing of it that might make it more clear.
|
||||
|
||||
|
Reference in New Issue
Block a user