Since the functional autograd + compiled autograd migration, we don't trace into nodes anymore, and everything is lifted. We can't support this flag which tries to inline make_fx style in CA initial pass. There's no more usage internally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146720
Approved by: https://github.com/zou3519
This PR is on the way to getting compiled autograd's initial capture to
stop specializing on Tensor metadata.
This PR changes compiled autograd's initial capture to proxy an opaque
(w.r.t. Dynamo) function into the graph for all built-in codegen'ed
autograd nodes and validate_outputs.
We changed each codegen'ed apply_with_saved (e.g.
MulBackward0::apply_with_saved) to call into Python to proxy a function
(compiled_autograd.ops.MulBackward0) into the graph. Then, we use the
node's InputMetadata to "guess" at the properties of the output Tensors
to create some new FakeTensors.
Some details:
- MulBackward0::apply_with_saved lives in libtorch_cpu, but needs to be
call to Python via libtorch_python. There is an indirection
(PyCompilerInterface) to do this.
- MulBackward0::apply_with_saved passes a C++ function to Python. To make
our lives easier, every codegen'ed apply_with_saved passes a C++
function with the same signature
`(variable_list, ivalue_list) -> variable_list`.
- We define how to pack arbitrary C++ types into IValue via a helper
IValuePacker struct and codegen functional variants of each builtin
C++ autograd node (e.g. MulBackward0_apply_functional_ivalue).
MulBackward0 before this PR:
https://gist.github.com/zou3519/a80381d5fa38e970e413fcd91b0530de
MulBackward0 after this PR:
https://gist.github.com/zou3519/0c2eee8b3d8d96232b51ef430b53c5b0
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143296
Approved by: https://github.com/jansel
- `FakeContext` hides all fields other than ctx.saved_tensors, this dynamo errors when the autograd.Function.backward uses other attrs on ctx and it also doesn't allow fallback to eager.
- If we remove it, we still can't fallback to eager: node variables are already freed (ctx.saved_tensors throws)
- However, we can fallback to "pseudo-eager" by using a duck-typed ctx and routing the ctx.saved_tensors to lifted tensors
- Dynamo tries to inline external_utils.call_backward, treats BackwardCFunction as a AutogradFunctionContextVariable (only used up until we create the fake context: FakeBackwardCFunction)
- we call_function backward from the forward class AutogradFunctionVariable, and we still pass in the fake context as a UserDefinedObjectVariable (can later use AutogradFunctionContextVariable + HOO graph speculate)
Fixes#125489#124827
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125661
Approved by: https://github.com/jansel
This adds support for backwards hooks that are *both*:
1) Interior to the graph; and
2) Dynamically generated (e.g. lambdas)
We do this by creating a BackwardState object that is used to register the hooks in the forward, then populated by dynamo *after* the forwards runs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120382
Approved by: https://github.com/xmfan
This PR adds support for torch.autograd.Function subclasses in compiled autograd. We do this by:
- Creating a uid for all torch.autograd.Function via its metaclass. This uid is used in the compiled autograd key, which is a subset of the cache key to the compiled graph
- "Lifting" the backward/saved_tensors, having them as input arguments in the compiled graph
- Creating proxies to track the backward's inputs and outputs. Since the backward's outputs (grads) have to match the forward's inputs, we pass the node's `input_info` (forward's input sizes) to build the proxies tracking the backward's outputs.
- Use a `FakeContext` class as a replacement for the autograd node's context object (`BackwardCFunction`) during tracing, only support passing saved_tensors from the forward to the backward
- Index each backward, to support multiple torch.autograd.Functions in the same graph
- Special case for `CompiledFunctionBackward`, lifting CompiledFunction will fail 4 tests and requires some skipfiles changes that I'd rather do that in a separate PR
Example graph: test_custom_fn_saved_multiple_tensors (eager fw + compiled autograd)
```python
class MyFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return torch.sin(x), torch.sin(y)
@staticmethod
def backward(ctx, gO_x, gO_y):
(x, y) = ctx.saved_tensors
return gO_x * torch.cos(x), gO_y * torch.cos(y)
```
The backwards is lifted via `getitem_5` and `call_backward`
```python
# Compiled autograd graph
===== Compiled autograd graph =====
<eval_with_key>.0 class CompiledAutograd(torch.nn.Module):
def forward(self, inputs, sizes, hooks):
# No stacktrace found for following nodes
getitem: "f32[]" = inputs[0]
getitem_1: "f32[10]" = inputs[1]
getitem_2: "f32[10]" = inputs[2]
getitem_3: "f32[10]" = inputs[3]
getitem_4: "f32[10]" = inputs[4]; inputs = None
expand: "f32[10]" = torch.ops.aten.expand.default(getitem, [10]); getitem = None
mul: "f32[10]" = torch.ops.aten.mul.Tensor(expand, getitem_2); getitem_2 = None
mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(expand, getitem_1); expand = getitem_1 = None
getitem_5 = hooks[0]; hooks = None
call_backward = torch__dynamo_external_utils_call_backward(getitem_5, (getitem_3, getitem_4), mul_1, mul); getitem_5 = mul_1 = mul = None
getitem_6: "f32[10]" = call_backward[0]
getitem_7: "f32[10]" = call_backward[1]; call_backward = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_7); getitem_4 = getitem_7 = None
accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_6); getitem_3 = getitem_6 = None
return []
```
then is later inlined by dynamo
```python
# Dynamo graph
===== __compiled_fn_0 =====
<eval_with_key>.1 class GraphModule(torch.nn.Module):
def forward(self, L_inputs_0_ : torch.Tensor, L_inputs_1_ : torch.Tensor, L_inputs_2_ : torch.Tensor, L_inputs_3_ : torch.Tensor, L_inputs_4_ : torch.Tensor):
getitem = L_inputs_0_
getitem_1 = L_inputs_1_
getitem_2 = L_inputs_2_
x = L_inputs_3_
y = L_inputs_4_
# File: <eval_with_key>.0:10, code: expand = torch.ops.aten.expand.default(getitem, [10]); getitem = None
expand = torch.ops.aten.expand.default(getitem, [10]); getitem = None
# File: <eval_with_key>.0:11, code: mul = torch.ops.aten.mul.Tensor(expand, getitem_2); getitem_2 = None
mul = torch.ops.aten.mul.Tensor(expand, getitem_2); getitem_2 = None
# File: <eval_with_key>.0:12, code: mul_1 = torch.ops.aten.mul.Tensor(expand, getitem_1); expand = getitem_1 = None
mul_1 = torch.ops.aten.mul.Tensor(expand, getitem_1); expand = getitem_1 = None
# File: /data/users/xmfan/core/pytorch/test/inductor/test_compiled_autograd.py:412, code: return gO_x * torch.cos(x), gO_y * torch.cos(y)
cos = torch.cos(x)
getitem_6 = mul_1 * cos; mul_1 = cos = None
cos_1 = torch.cos(y)
getitem_7 = mul * cos_1; mul = cos_1 = None
# File: <eval_with_key>.0:17, code: accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_7); getitem_4 = getitem_7 = None
accumulate_grad__default = torch.ops.inductor.accumulate_grad_.default(y, getitem_7); y = getitem_7 = None
# File: <eval_with_key>.0:18, code: accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_6); getitem_3 = getitem_6 = None
accumulate_grad__default_1 = torch.ops.inductor.accumulate_grad_.default(x, getitem_6); x = getitem_6 = None
return ()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115573
Approved by: https://github.com/jansel
This can be useful for advanced users (like AOTAutograd) who don't want to keep the corresponding Tensor alive (for memory reasons for example) or when inplace op will change the Tensor's grad_fn (but gradients wrt to the original value is needed).
I went minimal API change but open to suggestions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110867
Approved by: https://github.com/soulitzer
This branch:
1) converts the autograd tape into an FX graph
2) caches that conversion using a "shadow" graph
3) compiles and runs the generated FX graph instead of the normal autograd
What works currently:
1) Caching, capture, and initial integration
2) Backwards hooks
3) Inlining AotAutograd generated subgraphs
4) torch.compiling the generated FX graph
5) Auto-detecting dynamic shapes based on changes
Future work
1) Larger scale testing
1) Boxed calling convention, so memory can be freed incrementally
1) Support hooks on SavedTensor
1) Additional testing by running eager autograd tests under compiled_autograd.enable()
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103822
Approved by: https://github.com/ezyang, https://github.com/albanD
Fixes https://github.com/pytorch/pytorch/issues/104272
This PR adds a new private API `materialize_non_diff_grads` (default True) such that when set to False, grad outputs corresponding to outputs marked non-differentiable would receive None instead of a zero-filled tensor. This is overrides the setting of `materialize_grads`, i.e. grad outputs corresponding non-differentiable outputs would still be None even if `materialize_grads=True` (the default).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104291
Approved by: https://github.com/albanD
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71569
Not sure if this is the right API
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D33695395
Pulled By: soulitzer
fbshipit-source-id: 652b5758f15d901f98ff0da94e977030c7f3415b
(cherry picked from commit 9421a6846ad35cebbb84bd052769527505092a0c)
Summary:
As GoogleTest `TEST` macro is non-compliant with it as well as `DEFINE_DISPATCH`
All changes but the ones to `.clang-tidy` are generated using following script:
```
for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`; do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008
Reviewed By: driazati, r-barnes
Differential Revision: D29838584
Pulled By: malfet
fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
Summary:
This PR suppresses clang-tidy warnings in the codebase (for now) so that we can re-enable clang-tidy checks on master.
I ran this script to add the `NOLINTNEXTLINE` comments (on a devserver):
```bash
python3 setup.py develop
# Uses same script that's run on CI and adds the -j (parallel), -s (add comments), -k (continue if diagnostic errors are found) options
python3 tools/clang_tidy.py \
-j \
-s \
-k \
-v \
--paths torch/csrc/ \
-g"-torch/csrc/jit/passes/onnx/helper.cpp" \
-g"-torch/csrc/jit/passes/onnx/shape_type_inference.cpp" \
-g"-torch/csrc/jit/serialization/onnx.cpp" \
-g"-torch/csrc/jit/serialization/export.cpp" \
-g"-torch/csrc/jit/serialization/import.cpp" \
-g"-torch/csrc/jit/serialization/import_legacy.cpp" \
-g"-torch/csrc/onnx/init.cpp" \
-g"-torch/csrc/cuda/nccl.*" \
-g"-torch/csrc/cuda/python_nccl.cpp" \
-g"-torch/csrc/autograd/FunctionsManual.cpp" \
-g"-torch/csrc/generic/*.cpp" \
-g"-torch/csrc/jit/codegen/cuda/runtime/*" \
-g"-torch/csrc/deploy/interpreter/interpreter.cpp" \
-g"-torch/csrc/deploy/interpreter/interpreter.h" \
-g"-torch/csrc/deploy/interpreter/interpreter_impl.h" \
-g"-torch/csrc/deploy/interpreter/test_main.cpp"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60649
Test Plan: Verified changes by re-running the script (without the `-s` option) and seeing no warnings/errors.
Reviewed By: walterddr, janeyx99
Differential Revision: D29504258
Pulled By: 1ntEgr8
fbshipit-source-id: 78310b30ee8213b73ddb4771ad874665323e7a4e
Summary:
Fixes https://github.com/pytorch/pytorch/issues/30696
### Release Notes
Instantiating a custom autograd function is now deprecated. Users should call `.apply()` on the class itself because it is a static method.
--end release notes--
- There are a couple error messages that we can't entirely remove because accessing these attributes of the autograd function instance may segfault (due to cdata being nullptr). Also added a TORCH_CHECK for the name attribute which previously segfaulted.
- Error message updated to convey 1) old-style functions have been deprecated 2) this access pattern was once valid
- Updates variable -> Tensor for some error messages
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57357
Reviewed By: mrshenli
Differential Revision: D28193095
Pulled By: soulitzer
fbshipit-source-id: f021b105e9a3fd4a20d6ee3dfb6a06a8c34b10ca
Summary:
Such a deadlock was found for PyFunctionPreHook after adding https://github.com/pytorch/pytorch/pull/57057
This is fixing all occurrences in torch/csrc/autograd
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57488
Reviewed By: malfet
Differential Revision: D28163321
Pulled By: albanD
fbshipit-source-id: 4daf1db69674e73967fc7c5ca2a240c61340e7ca
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os
def get_compiled_files_list():
import json
with open("build/compile_commands.json") as f:
data = json.load(f)
files = [os.path.relpath(node['file']) for node in data]
for idx, fname in enumerate(files):
if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
return files
def run_clang_tidy(fname):
check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname,"-s"])
changes = check_output(["git", "ls-files", "-m"])
if len(changes) == 0:
return
check_call(["git", "commit","--all", "-m", f"NOLINT stubs for {fname}"])
def main():
git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
compiled_files = get_compiled_files_list()
for idx, fname in enumerate(git_files):
if fname not in compiled_files:
continue
if fname.startswith("caffe2/contrib/aten/"):
continue
print(f"[{idx}/{len(git_files)}] Processing {fname}")
run_clang_tidy(fname)
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
Summary:
Context: https://github.com/pytorch/pytorch/pull/53299#discussion_r587882857
These are the only hand-written parts of this diff:
- the addition to `.github/workflows/lint.yml`
- the file endings changed in these four files (to appease FB-internal land-blocking lints):
- `GLOSSARY.md`
- `aten/src/ATen/core/op_registration/README.md`
- `scripts/README.md`
- `torch/csrc/jit/codegen/fuser/README.md`
The rest was generated by running this command (on macOS):
```
git grep -I -l ' $' -- . ':(exclude)**/contrib/**' ':(exclude)third_party' | xargs gsed -i 's/ *$//'
```
I looked over the auto-generated changes and didn't see anything that looked problematic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53406
Test Plan:
This run (after adding the lint but before removing existing trailing spaces) failed:
- https://github.com/pytorch/pytorch/runs/2043032377
This run (on the tip of this PR) succeeded:
- https://github.com/pytorch/pytorch/runs/2043296348
Reviewed By: walterddr, seemethere
Differential Revision: D26856620
Pulled By: samestep
fbshipit-source-id: 3f0de7f7c2e4b0f1c089eac9b5085a58dd7e0d97
Summary:
Added a new option in AutogradContext to tell autograd to not materialize output grad tensors, that is, don't expand undefined/None tensors into tensors full of zeros before passing them as input to the backward function.
This PR is the second part that closes https://github.com/pytorch/pytorch/issues/41359. The first PR is https://github.com/pytorch/pytorch/pull/41490.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41821
Reviewed By: albanD
Differential Revision: D22693163
Pulled By: heitorschueroff
fbshipit-source-id: a8d060405a17ab1280a8506a06a2bbd85cb86461
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34845
This PR allows PyNode to persist the error message so that any pure C++
thread that runs autograd with custom Python autograd function can successfully
catpure the error message without maintaining a initial PyThreadState.
Test Plan: Imported from OSS
Differential Revision: D20480685
Pulled By: wanchaol
fbshipit-source-id: 0488ea5a4df9a33b53ac5d0d59000c41ab6cb748
Summary:
Given that pybind11 implements these gil functions, I don't think it makes sense for Pytorch to have its own bespoke versions.
Fixes https://github.com/pytorch/pytorch/issues/29065
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29095
Differential Revision: D18301806
Pulled By: ezyang
fbshipit-source-id: 03da6a26c41ee65aaadf7b67b9f0b14d2def2a5a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23572
### **(The stack from #23020 was moved into this PR)**
Adding API for custom autograd operations, with user defined forward and backward, [like in python](https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd).
The custom operation should be a subclass of Function, with static forward and backward functions. `forward()` can accept any arguments similar to the Python API and `backward()` should accept a variable list as an argument.
Both `forward()` and `backward() `accept a AutogradContext* which can be used to share data between them.
Variables can be saved in the context using `save_for_backward()` and other data can be saved in the map `save` in the form of `<std::string, at::IValue>` pairs. Variables saved in forward can be accessed with `get_saved_variables()`.
Example usage:
```
class MyFunction : public Function<MyFunction> {
public:
static variable_list forward(AutogradContext *ctx, int n, Variable var) {
// Save data for backward in context
ctx->saved_data["n"] = n;
return {var};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
// Use data saved in forward
auto n = ctx->saved_data["n"].toInt();
return {grad_output[0]*n};
}
};
```
Then, it can be used with:
```
Variable x;
MyFunction::apply(6, x);
```
Also AutogradContext has methods to mark outputs as non differentiable and mark inputs as dirty similar to the [Python API](ff23a02ac4/torch/autograd/function.py (L26)).
Test Plan: Added tests for the custom autograd function API based on test_autograd.py. Currently only the tests for the basic functionality have been added. More tests will be added later.
Differential Revision: D16583428
fbshipit-source-id: 0bd42f19ce37bcd99d3080d16195ad74d40d0413
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15316
This starts cleaning up the files in c10 according to the module structure we decided on.
Move to c10/util:
- Half.h, Half-inl.h, Half.cpp, bitcasts.h
Move to c10/core:
- Device.h, Device.cpp
- DeviceType.h, DeviceType.cpp
i-am-not-moving-c2-to-c10
Reviewed By: dzhulgakov
Differential Revision: D13498493
fbshipit-source-id: dfcf1c490474a12ab950c72ca686b8ad86428f63
Summary:
Anywhere we used #include "foo.h", we now say #include <foo.h>
Paths are adjusted to be rooted out of aten/src, torch/lib, or
the root level directory.
I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.
I used the following script to do the canonicalization:
```
import subprocess
import re
import os.path
files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
for fn in files:
if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cu', '.cuh', '.cc']):
continue
if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
continue
with open(fn, 'r') as f:
c = f.read()
def fmt(p):
return "#include <{}>".format(p)
def repl(m):
p = m.group(1)
if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
return fmt(p)
if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
return fmt(p)
for root in ["aten/src", "torch/lib", ""]:
for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
new_p = os.path.relpath(os.path.join(bad_root, p), root)
if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
return fmt(new_p)
print("ERROR: ", fn, p)
return m.group(0)
new_c = re.sub(r'#include "([^"]+)"', repl, c)
if new_c != c:
print(fn)
with open(fn, 'w') as f:
f.write(new_c)
```
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849
Reviewed By: dzhulgakov
Differential Revision: D13363445
Pulled By: ezyang
fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13342
This PR introduces a few new concepts:
- DeviceGuardImplInterface, and implementations for CPU and CUDA, which
provide a generic interface for interfacing with device and stream state,
without requiring a direct dependency on the code in question.
- InlineDeviceGuard, a general template for generating both specialized
and dynamically dispatched device guard implementations. Dynamic
dispatch is done by specializing it on a VirtualGuardImpl.
- Provide a device-independent DeviceGuard class, which can be used even
from CPU code. It uses the aforementioned dynamic dispatch.
- CUDA-specialized CUDAGuard class, which doesn't have a dynamic dispatch
but can only be used from CUDA.
- StreamGuard, which is the same as above, but for streams rather than
devices.
- Optional variants of all the aforementioned guards, which are a no-op if
no device/stream is specified
- CUDAMultiStreamGuard, specifically for the case when we want to set
a device on every guard.
There are some subtle semantic changes, which have been thoroughly documented
in the class definition.
BC-breaking changes:
- Move constructor/assignment have been removed from all device guard
implementations.
- In some cases where you previously wrote 'set_device' (or 'set_stream'), you now must write
'reset_device', because if you switch devices/device types, the stream/device on the
previous device is unset. This is different from previous behavior.
- CUDAGuard no longer handles streams, or multiple streams. Use CUDAStreamGuard
or CUDAMultiStreamGuard as appropriate for your use case.
Reviewed By: dzhulgakov
Differential Revision: D12849620
fbshipit-source-id: f61956256f0b12be754b3234fcc73c2abc1be04e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13275
This resulted in a bunch of knock-on changes, which I will now
describe:
- s/original_index/original_device/
- s/last_index/last_device/
- A bunch of places that used set_index, now use CUDAGuard (which does have
set_index) because they were CUDA-specific code.
Major caveat: DeviceGuard doesn't *actually* work non-CUDA/CPU devices, To make
that happen, I plan on totally replacing the implementation of DeviceGuard; what
I mostly care about here is wrangling the API into an acceptable state.
Reviewed By: gchanan
Differential Revision: D12832080
fbshipit-source-id: 7de068c7cec35663dc8a533026a626331336e61d