mirror of https://github.com/pytorch/pytorch.git
MPS: Fix clamp scalar cache key to store floats in hex representation (#167777)
Fixes #167767.

The original issue was that using `std::to_string(value)` does not work as intended here if the value is smaller than 1e-6: the caching keys ended up as `clamp_out_mps_min:0.000000_scalar::f32[1]` instead of `clamp_out_mps_min:0.0000001_scalar::f32[1]`, so `clamp(min=1e-7)` hit the graph cached for `clamp(min=0.0)` and returned the wrong result.

After the change, the values are stored as the hex representation of the floating-point number. So for min_value 1e-7 the key will be `impl_min:0x1.ad7f2ap-24_scalar::f32[1]`, and for min_value 0.0 it will be `clamp_out_mps_min:0x0p+0_scalar::f32[1]`.

Output of the repro code before the change:

```
tensor([0.], device='mps:0')
tensor([0.], device='mps:0')
tensor([0.], device='mps:0')
tensor([0.], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
```

Output of the repro code after the change:

```
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
```

which matches the expected CPU reference.

Snippet to test with:

```
import torch

device='mps'
dtype=torch.float32

a = torch.zeros(1, device=device, dtype=dtype)
# the following line triggers the incorrect behavior; when commented out,
# the remainder of the script works as expected
a_clamped = a.clamp(min=0.0)

b = torch.zeros(1, device=device)
print(b)
c = b.clamp(min=1e-7)
print(c)

b = torch.zeros(1, device=device)
print(b)
c = b.clamp(min=1e-7, max=None)
print(c)

b = torch.zeros(1, device=device)
print(b)
c = b.clamp(min=1e-7, max=torch.inf)
print(c)

b = torch.zeros(1, device=device)
print(b)
c = b.clamp_min(1e-7)
print(c)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167777
Approved by: https://github.com/malfet
committed by PyTorch MergeBot
parent de0d69b2c4
commit 3d7a8b7e61
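The root cause is easy to see in isolation. Below is a minimal standalone C++ sketch (standard library only; not part of the patch) contrasting `std::to_string`, which always formats floats with six fixed decimals, against the `%a` hexfloat conversion that the fix's `to_hex_key` relies on:

```cpp
#include <cstdio>
#include <iostream>
#include <string>

int main() {
  float zero = 0.0f;
  float tiny = 1e-7f;
  // std::to_string uses a fixed "%f"-style format with six decimals,
  // so any magnitude below ~5e-7 collapses to the same string as 0.0:
  std::cout << std::to_string(zero) << "\n";  // 0.000000
  std::cout << std::to_string(tiny) << "\n";  // 0.000000  -> key collision
  // The hexfloat conversion is exact and keeps the values distinct:
  std::printf("%a\n", static_cast<double>(zero));  // 0x0p+0
  std::printf("%a\n", static_cast<double>(tiny));  // e.g. 0x1.ad7f2ap-24
  return 0;
}
```

The exact digit count of the `%a` output can vary by libc, but the value printed is always an exact rendering of the float's bits, which is what makes it usable as a cache key component.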
```diff
@@ -82,6 +82,7 @@ NSArray<NSNumber*>* getTensorAxes(const TensorBase& t);
 NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
 std::string getMPSShapeString(MPSShape* shape);
 std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
+std::string to_hex_key(float);
 std::string getArrayRefString(const IntArrayRef s);
 // use has_storage() on the returned tensor to determine if src actually is a view
 Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
@@ -301,6 +301,10 @@ std::string getArrayRefString(const IntArrayRef s) {
   return fmt::to_string(fmt::join(s, ","));
 }
 
+std::string to_hex_key(float f) {
+  return fmt::format("{:a}", f);
+}
+
 std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
   fmt::basic_memory_buffer<char, 100> buffer;
   auto buf_iterator = std::back_inserter(buffer);
@@ -244,8 +244,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
 
   @autoreleasepool {
     // the optional min/max refs could affect how we build the cached graph
-    std::string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
-        (has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
+    std::string key = op_name + (has_min ? ("_min:" + to_hex_key(min_scalar)) : "") +
+        (has_max ? ("_max:" + to_hex_key(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       if (has_min)
         newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar
```
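Why hexfloat is safe here: `fmt`'s `{:a}` (like `printf`'s `%a`) prints the float's binary significand and exponent exactly, so two distinct floats can never format to the same string, and the string parses back to the identical value. A minimal sketch of that round-trip property, assuming the {fmt} library is available as in the PyTorch build; the variable names are illustrative only:

```cpp
#include <cassert>
#include <cstdlib>
#include <string>

#include <fmt/format.h>

int main() {
  // Hexfloat output is a lossless rendering of the float's bits, so
  // distinct clamp bounds always yield distinct key fragments ...
  float min_scalar = 1e-7f;  // illustrative clamp bound
  std::string frag = fmt::format("{:a}", min_scalar);  // e.g. "0x1.ad7f2ap-24"
  // ... and the original value can be recovered exactly from the key
  // (strtof accepts C99 hexadecimal floating constants).
  assert(std::strtof(frag.c_str(), nullptr) == min_scalar);
  return 0;
}
```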