Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

[functorch] modified example

@@ -105,20 +105,21 @@ When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
 ```py
 >>> from functorch import vmap
 >>> batch_size, feature_size = 3, 5
->>> weights = torch.randn(feature_size, requires_grad=True)
 >>>
->>> def model(feature_vec):
+>>> def model(weights, feature_vec):
 >>>     # Very simple linear model with activation
 >>>     assert feature_vec.dim() == 1
 >>>     return feature_vec.dot(weights).relu()
 >>>
 >>> def compute_loss(weights, example, target):
->>>     y = model(example)
->>>     return ((y - t) ** 2).mean()  # MSELoss
+>>>     y = model(weights, example)
+>>>     return ((y - target) ** 2).mean()  # MSELoss
 >>>
+>>> weights = torch.randn(feature_size, requires_grad=True)
 >>> examples = torch.randn(batch_size, feature_size)
 >>> targets = torch.randn(batch_size)
->>> grad_weight_per_example = vmap(grad(compute_loss))(weights, examples, targets)
+>>> inputs = (weights, examples, targets)
+>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
 ```

 ### vjp and jacrev
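
This hunk makes `compute_loss` a pure function of `(weights, example, target)`: `model` now receives `weights` explicitly instead of closing over a global, the previously undefined `t` becomes `target`, and `in_dims=(None, 0, 0)` tells `vmap` that `weights` is not batched while `examples` and `targets` are mapped over their leading dimension. Below is a minimal sketch (not part of this commit; the names `per_sample_grads` and `expected` are illustrative) that cross-checks the vmapped result against a plain per-example loop:

```py
import torch
from functorch import grad, vmap

batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

def model(weights, feature_vec):
    # Same toy linear model with a ReLU activation as in the README example.
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    # Scalar MSE loss for a single example.
    return ((model(weights, example) - target) ** 2).mean()

# in_dims=(None, 0, 0): share `weights` across the batch, map over dim 0 of
# `examples` and `targets`.
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)

# Reference computed the slow way: one autograd call per example.
expected = torch.stack([
    torch.autograd.grad(compute_loss(weights, examples[i], targets[i]), weights)[0]
    for i in range(batch_size)
])
assert per_sample_grads.shape == (batch_size, feature_size)
assert torch.allclose(per_sample_grads, expected)
```

Passing `weights` as the first argument (rather than capturing it from the enclosing scope) is what lets `grad(compute_loss)` differentiate with respect to it and lets `in_dims` declare it as unbatched.
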
@@ -149,8 +150,8 @@ batched jacobians:
 `jacrev` can be composed with itself to produce hessians:
 ```py
 >>> def f(x):
 >>>     return x.sin().sum()
 >>>
 >>> x = torch.randn(5)
 >>> hessian = jacrev(jacrev(f))(x)
 ```
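
The second hunk carries no changes, only context around the hessian example. For `f(x) = x.sin().sum()` the Hessian is `diag(-sin(x))`, so a quick sanity check of the composed `jacrev` call (a sketch against that closed form, not part of the commit) is:

```py
import torch
from functorch import jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hessian = jacrev(jacrev(f))(x)

# For f(x) = sum(sin(x)), d^2 f / dx_i dx_j is -sin(x_i) when i == j, else 0.
assert torch.allclose(hessian, torch.diag(-x.sin()))
```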