Update placement utils and weights to handle meta device (#162842)

Summary:
This diff fixes two issues that come up when testing a tgif-published pt2 model remote net:
1) Updates isSameDevice to handle the meta device, which avoids this error:
```
what():  Unsupported device typemeta and meta
Exception raised from isSameDevice at fbcode/caffe2/torch/nativert/executor/PlacementUtils.cpp:20
```

2) Updates the xl weight v2 loading logic in Weights.cpp to handle non-TBE xl weights. Today, when a weight is replaced through ModelRunnerAdapter.setAttr(), we enforce that the old weight and the new weight are on the same device. However, non-TBE xl weights are replaced by finding every weight on the "meta" device and swapping in the correct weight, on its real device, from the xl_weights folder. The old and new weights therefore always have different devices, so the device check is invalid on this path. I don't think we've run into this so far because non-TBE xl weights had not been thoroughly tested until now.
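
To make the two changes concrete, here is a minimal, self-contained sketch of the intended behavior. It is not the nativert code: `DeviceType`, `Device`, and `ToyWeights` are hypothetical stand-ins for `c10::Device` and `Weights`, the comparison is simplified, and an unknown weight name throws here purely to keep the toy short.

```cpp
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for c10::Device; only what the device check needs.
enum class DeviceType { CPU, CUDA, Meta };

struct Device {
  DeviceType type;
  int index; // -1 means "no index"
  bool is_meta() const { return type == DeviceType::Meta; }
};

// Simplified comparison (assumption): meta matches only meta, and every
// other device is compared by type and index.
bool isSameDevice(const Device& a, const Device& b) {
  if (a.is_meta()) {
    return b.is_meta();
  }
  return a.type == b.type && a.index == b.index;
}

// Toy weight table that tracks only the device of each named weight.
class ToyWeights {
 public:
  void registerWeight(const std::string& name, const Device& device) {
    devices_[name] = device;
  }

  // Throws std::runtime_error on a device mismatch unless skipDeviceCheck
  // is set. Throws std::out_of_range if the name was never registered.
  void setValue(
      const std::string& name,
      const Device& newDevice,
      bool skipDeviceCheck) {
    Device& current = devices_.at(name);
    if (!skipDeviceCheck && !isSameDevice(current, newDevice)) {
      throw std::runtime_error("Mismatched device for " + name);
    }
    current = newDevice;
  }

  const Device& deviceOf(const std::string& name) const {
    return devices_.at(name);
  }

 private:
  std::map<std::string, Device> devices_;
};
```

In this sketch, replacing a meta placeholder with a CPU weight throws when `skipDeviceCheck` is false and succeeds when it is true, which mirrors why the non-TBE xl weight path needs to opt out of the check.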

Test Plan:
Run the MRS you model merge net, which uses non-TBE xl weights. Confirm that before change #1 we get the error:
```
Unsupported device typemeta and meta
```
Then after change #1 and before change #2 we get:
```
what():  Mismatched device for merge.user_tower.linear.weight: meta vs cpu
Exception raised from validateValue at fbcode/caffe2/torch/nativert/executor/Weights.cpp:374
```
After change #2, the run is successful.
Command:
```
MODEL_ENTITY_ID=921242082
SNAPSHOT_ID=1269
module_name=merge
SAMPLE_INPUT_DIR=/data/users/georgiaphillips/models/921242082/${SNAPSHOT_ID}/${module_name}_archive/package/data/sample_inputs
buck2 run mode/dev-nosan -c fbcode.nvcc_arch=h100,a100 -c fbcode.enable_gpu_sections=true caffe2/torch/fb/model_transform/fx2trt/packaging:load_net_predictor -- --loadMode=Benchmark --inputNetFile=/data/users/$USER/models/${MODEL_ENTITY_ID}/${SNAPSHOT_ID}/${MODEL_ENTITY_ID}_${SNAPSHOT_ID}.predictor.${module_name} --moduleName=${module_name} --submodToDevice="merge|cuda0"  --benchmarkEnableProfiling=false --disableStaticRuntime=true --doNotRandomizeSampleInputs=true --benchmarkDontRebatchSamples=true --pytorch_predictor_sigmoid_static_dispatch_enable=false --pytorch_predictor_sigmoid_graph_passes_enable=false --sampleInputFilePath=${SAMPLE_INPUT_DIR}/${module_name}.pt
```

Rollback Plan:

Differential Revision: D80713052

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162842
Approved by: https://github.com/henryoier
Georgia Phillips
2025-09-17 08:12:32 +00:00
committed by PyTorch MergeBot
parent a5419743c6
commit b229455ddd
3 changed files with 39 additions and 12 deletions

@@ -17,6 +17,9 @@ bool isSameDevice(const c10::Device& a, const c10::Device& b) {
return false;
}
}
if (a.is_meta()) {
return b.is_meta();
}
TORCH_CHECK(false, "Unsupported device type", a, " and ", b);
return false;
}

@@ -337,6 +337,13 @@ void Weights::loadStateDict(
void Weights::validateValue(const std::string& name, const at::Tensor& newValue)
const {
validateValue(name, newValue, /*skipDeviceCheck=*/false);
}
void Weights::validateValue(
const std::string& name,
const at::Tensor& newValue,
bool skipDeviceCheck) const {
auto& weightMeta = weightsMeta_.at(name);
TORCH_CHECK(
@@ -360,23 +367,32 @@ void Weights::validateValue(const std::string& name, const at::Tensor& newValue)
" vs ",
newValue.dtype());
auto targetDevice = weightMeta.device();
if (targetDevice.is_cpu() && targetDevice.has_index()) {
LOG(WARNING) << "Target device is cpu but has index: " << targetDevice;
if (!skipDeviceCheck) {
auto targetDevice = weightMeta.device();
if (targetDevice.is_cpu() && targetDevice.has_index()) {
LOG(WARNING) << "Target device is cpu but has index: " << targetDevice;
}
TORCH_CHECK(
isSameDevice(targetDevice, newValue.device()),
"Mismatched device for ",
name,
": ",
targetDevice,
" vs ",
newValue.device());
}
TORCH_CHECK(
isSameDevice(targetDevice, newValue.device()),
"Mismatched device for ",
name,
": ",
targetDevice,
" vs ",
newValue.device());
}
void Weights::setValue(const std::string& name, const at::Tensor& newValue) {
setValue(name, newValue, /*skipDeviceCheck=*/false);
}
void Weights::setValue(
const std::string& name,
const at::Tensor& newValue,
bool skipDeviceCheck) {
if (allValues_.find(name) != allValues_.end()) {
validateValue(name, newValue);
validateValue(name, newValue, skipDeviceCheck);
} else {
LOG(WARNING) << name << " is not found in the registered weights";
}

@@ -66,6 +66,10 @@ class Weights {
* Replace the value stored at the weight with name "name".
*/
void setValue(const std::string& name, const at::Tensor& newValue);
void setValue(
const std::string& name,
const at::Tensor& newValue,
bool skipDeviceCheck);
/*
* Update the value stored at the weight with name "name".
@@ -77,6 +81,10 @@ class Weights {
const std::unordered_map<std::string, at::Tensor>& newValues);
void validateValue(const std::string& name, const at::Tensor& newValue) const;
void validateValue(
const std::string& name,
const at::Tensor& newValue,
bool skipDeviceCheck) const;
void validateAllWeightsLoaded();