[functorch] modified example
@@ -105,20 +105,21 @@ When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
 ```py
 >>> from functorch import vmap
 >>> batch_size, feature_size = 3, 5
 >>> weights = torch.randn(feature_size, requires_grad=True)
 >>>
->>> def model(feature_vec):
+>>> def model(weights, feature_vec):
 >>>   # Very simple linear model with activation
 >>>   assert feature_vec.dim() == 1
 >>>   return feature_vec.dot(weights).relu()
 >>>
 >>> def compute_loss(weights, example, target):
->>>   y = model(example)
->>>   return ((y - t) ** 2).mean()  # MSELoss
+>>>   y = model(weights, example)
+>>>   return ((y - target) ** 2).mean()  # MSELoss
 >>>
 >>> weights = torch.randn(feature_size, requires_grad=True)
 >>> examples = torch.randn(batch_size, feature_size)
 >>> targets = torch.randn(batch_size)
->>> grad_weight_per_example = vmap(grad(compute_loss))(weights, examples, targets)
+>>> inputs = (weights, examples, targets)
+>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
 ```
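For reference, the fixed example assembled into a standalone script. This sketch is an editorial addition, not part of the commit: it presumes `functorch` is installed, and the loop-based `expected` comparison is an assumed way to verify the result, not README code. The substance of the fix is `in_dims=(None, 0, 0)`; the old call used the default `in_dims=0`, which would also try to slice `weights` along dimension 0.

```py
import torch
from functorch import grad, vmap

batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

def model(weights, feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSELoss

examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

# None: weights are shared across the batch and not mapped over;
# 0: map over dim 0 of examples and targets, one slice per sample.
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(
    weights, examples, targets
)
assert grad_weight_per_example.shape == (batch_size, feature_size)

# Sanity check: per-sample gradients computed one example at a time.
expected = torch.stack(
    [grad(compute_loss)(weights, ex, t) for ex, t in zip(examples, targets)]
)
assert torch.allclose(grad_weight_per_example, expected)
```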
### vjp and jacrev
@@ -149,8 +150,8 @@ batched jacobians:
 `jacrev` can be composed with itself to produce hessians:
 ```py
 >>> def f(x):
 >>>   return x.sin().sum()
 >>>
 >>> x = torch.randn(5)
 >>> hessian = jacrev(jacrev(f))(x)
 ```
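The hessian snippet can likewise be checked analytically. Another editorial sketch, assuming `functorch` is importable: for `f(x) = x.sin().sum()`, the Hessian is the diagonal matrix with entries `-sin(x_i)`, so the composed `jacrev` result has a closed form to compare against.

```py
import torch
from functorch import jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)

# jacrev(f) maps x to the gradient of f; applying jacrev again
# differentiates that gradient, giving the (5, 5) Hessian.
hessian = jacrev(jacrev(f))(x)

# For f(x) = sum(sin(x)): d^2 f / dx_i dx_j = -sin(x_i) if i == j, else 0.
assert torch.allclose(hessian, torch.diag(-x.sin()))
```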