Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Summary:
This diff fixes two things that come up when testing a tgif-published pt2 model remote net:

1) Updates isSameDevice to handle the meta device, avoiding this error:
```
 what(): Unsupported device typemeta and meta
Exception raised from isSameDevice at fbcode/caffe2/torch/nativert/executor/PlacementUtils.cpp:20
```

2) Updates the xl weight v2 loading logic in Weights.cpp to handle non-TBE xl weights. Today we enforce that the old and new weights are on the same device when replacing a weight with ModelRunnerAdapter.setAttr(). However, the way we replace non-TBE xl weights is to find any weights on the "meta" device and then replace them with the correct weights, on their real devices, from the xl_weights folder. Therefore the new weight and the old weight will always be on different devices, and the device check is invalid. I don't think we've run into this so far because non-TBE xl weights had not been thoroughly tested until now.

Test Plan:
Run the MRS model merge net, which uses non-TBE xl weights. Confirm that before change #1 we get the error:
```
Unsupported device typemeta and meta
```
Then, after change #1 and before change #2, we get:
```
 what(): Mismatched device for merge.user_tower.linear.weight: meta vs cpu
Exception raised from validateValue at fbcode/caffe2/torch/nativert/executor/Weights.cpp:374
```
After change #2, the run succeeds.

Command:
```
MODEL_ENTITY_ID=921242082 SNAPSHOT_ID=1269 module_name=merge SAMPLE_INPUT_DIR=/data/users/georgiaphillips/models/921242082/${SNAPSHOT_ID}/${module_name}_archive/package/data/sample_inputs buck2 run mode/dev-nosan -c fbcode.nvcc_arch=h100,a100 -c fbcode.enable_gpu_sections=true caffe2/torch/fb/model_transform/fx2trt/packaging:load_net_predictor -- --loadMode=Benchmark --inputNetFile=/data/users/$USER/models/${MODEL_ENTITY_ID}/${SNAPSHOT_ID}/${MODEL_ENTITY_ID}_${SNAPSHOT_ID}.predictor.${module_name} --moduleName=${module_name} --submodToDevice="merge|cuda0" --benchmarkEnableProfiling=false --disableStaticRuntime=true --doNotRandomizeSampleInputs=true --benchmarkDontRebatchSamples=true --pytorch_predictor_sigmoid_static_dispatch_enable=false --pytorch_predictor_sigmoid_graph_passes_enable=false --sampleInputFilePath=${SAMPLE_INPUT_DIR}/${module_name}.pt
```

Rollback Plan:

Differential Revision: D80713052

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162842
Approved by: https://github.com/henryoier
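The Weights.cpp half of the fix (change #2) is not shown on this page. The sketch below only illustrates the idea described above; the helper name `validateReplacement` and its signature are hypothetical (the real check is `validateValue` in Weights.cpp, quoted in the trace above), and in the actual change the relaxation applies only to the xl weight v2 replacement path.
```
// Hypothetical sketch, not the actual Weights.cpp code: skip the strict
// device-equality check when the existing weight is a meta-device
// placeholder, since non-TBE xl weights are staged on "meta" and are
// always replaced by tensors on a real device.
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>
#include <string>

void validateReplacement(
    const std::string& name,
    const at::Tensor& oldWeight,
    const at::Tensor& newWeight) {
  if (oldWeight.device().is_meta()) {
    // Meta placeholder: the replacement is expected to live on a
    // different (real) device, so the device check does not apply.
    return;
  }
  TORCH_CHECK(
      oldWeight.device() == newWeight.device(),
      "Mismatched device for ", name, ": ",
      oldWeight.device(), " vs ", newWeight.device());
}
```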
torch/nativert/executor/PlacementUtils.cpp · 27 lines · 596 B · C++
#include <torch/nativert/executor/Placement.h>

#include <fmt/ostream.h>

namespace torch::nativert {

bool isSameDevice(const c10::Device& a, const c10::Device& b) {
  if (a.is_cpu()) {
    return b.is_cpu();
  }
  if (a.is_cuda()) {
    if (b.is_cuda()) {
      auto aIndex = a.has_index() ? a.index() : 0;
      auto bIndex = b.has_index() ? b.index() : 0;
      return aIndex == bIndex;
    } else {
      return false;
    }
  }
  if (a.is_meta()) {
    return b.is_meta();
  }
  TORCH_CHECK(false, "Unsupported device type", a, " and ", b);
  return false;
}

} // namespace torch::nativert
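For context, a quick, hypothetical usage snippet showing the behavior this file implements. It assumes isSameDevice is declared in the Placement.h header included above and that the PyTorch nativert sources are available to build against.
```
#include <torch/nativert/executor/Placement.h>

#include <c10/core/Device.h>
#include <iostream>

int main() {
  using torch::nativert::isSameDevice;
  // Meta devices compare equal to each other instead of hitting the
  // TORCH_CHECK at the bottom of isSameDevice.
  std::cout << isSameDevice(c10::Device("meta"), c10::Device("meta")) << '\n'; // 1
  // Meta vs. a real device reports false rather than throwing.
  std::cout << isSameDevice(c10::Device("meta"), c10::Device("cpu")) << '\n'; // 0
  // CUDA devices compare by index; a missing index is treated as index 0.
  std::cout << isSameDevice(c10::Device("cuda"), c10::Device("cuda:0")) << '\n'; // 1
  return 0;
}
```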