[functorch] modified example

Author: Horace He
Committed by: Jon Janzen
Date: 2021-04-28 16:10:11 -07:00
Parent: c24314c09b
Commit: ac9be17a87


@@ -105,20 +105,21 @@ When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
```py
>>> from functorch import vmap
>>> batch_size, feature_size = 3, 5
->>> weights = torch.randn(feature_size, requires_grad=True)
>>>
->>> def model(feature_vec):
+>>> def model(weights, feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
->>>     y = model(example)
->>>     return ((y - t) ** 2).mean()  # MSELoss
+>>>     y = model(weights, example)
+>>>     return ((y - target) ** 2).mean()  # MSELoss
>>>
+>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
->>> grad_weight_per_example = vmap(grad(compute_loss))(weights, examples, targets)
+>>> inputs = (weights, examples, targets)
+>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
```
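
As an illustrative sanity check (not part of this diff), the vmapped result can be compared against computing the gradient one example at a time in a plain Python loop, assuming `functorch.grad` as used above:

```py
>>> # Hypothetical check: per-sample gradients from vmap(grad(...)) should match
>>> # stacking the gradient of each example computed individually.
>>> from functorch import grad
>>> expected = torch.stack([grad(compute_loss)(weights, examples[i], targets[i]) for i in range(batch_size)])
>>> assert torch.allclose(grad_weight_per_example, expected)
```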
### vjp and jacrev
@@ -149,8 +150,8 @@ batched jacobians:
`jacrev` can be composed with itself to produce hessians:
```py
>>> def f(x):
->>> return x.sin().sum()
->>>
+>>>     return x.sin().sum()
+>>>
>>> x = torch.randn(5)
>>> hessian = jacrev(jacrev(f))(x)
```
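
As a hedged illustration (not part of this diff), the result can be checked against the closed-form Hessian: for `f(x) = x.sin().sum()` the gradient is `cos(x)`, so the Hessian is the diagonal matrix `diag(-sin(x))`:

```py
>>> # Hypothetical check: the computed Hessian should equal diag(-sin(x)).
>>> expected = torch.diag(-x.sin())
>>> assert torch.allclose(hessian, expected)
```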