[MPS] Cache multinomial_with_replacement graph (#86437)

Reuse the existing `RandomCachedGraph` to keep the RNG state as part of the graph.
Add a `CreateCachedGraphAs` convenience wrapper.
Addresses https://github.com/pytorch/pytorch/pull/86342#pullrequestreview-1132197848
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86437
Approved by: https://github.com/kulinseth
Nikita Shulga
2022-10-07 04:39:28 +00:00
committed by PyTorch MergeBot
parent 9ceadcadb2
commit 10aead9adc
2 changed files with 113 additions and 88 deletions
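The change moves multinomial_with_replacement_mps_kernel onto the same lookup-or-create pattern that random_mps_impl already uses, so the MPSGraph is built once per cache key and only the Philox state changes between runs. A condensed sketch of that pattern, pieced together from the hunks below (anything not visible in those hunks is an assumption, not a literal excerpt):

// Sketch of the cache pattern introduced for multinomial (condensed from the diff below).
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + to_string(n_sample);
auto cachedGraph = cache_->LookUpAs<RandomCachedGraph>(key);
if (!cachedGraph) {
  cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^ MPSCachedGraph * () {
    RandomCachedGraph* newCachedGraph = nil;
    @autoreleasepool {
      MPSGraph* mpsGraph = make_mps_graph();
      newCachedGraph = new RandomCachedGraph(mpsGraph);
      // ... placeholders and the sampling subgraph are built once here ...
    }
    return newCachedGraph;
  });
}
// Advance the Philox counters so the reused graph still yields fresh samples.
cachedGraph->updatePhiloxCounters();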


@@ -204,6 +204,11 @@ struct MPSGraphCache
     return result;
   }
+  template<typename T>
+  inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock, void* view_ptr = nullptr) {
+    return static_cast<T *>(CreateCachedGraph(key, createCacheBlock, view_ptr));
+  }
   MPSCachedGraph* LookUp(const std::string& key) const {
     __block MPSCachedGraph* result = nullptr;
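The new wrapper only centralizes the static_cast that call sites previously wrote by hand; as the hunks below show, a call site changes roughly like this (the createBlock name is illustrative):

// Before: cast the generic cached graph at every call site
auto* graph = static_cast<RandomCachedGraph *>(cache_->CreateCachedGraph(key, createBlock));
// After: the typed wrapper performs the cast
auto* graph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, createBlock);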


@@ -20,6 +20,8 @@ struct RandomCachedGraph : public MPSCachedGraph
     stateValues[5] = static_cast<uint32_t>(seed);
     stateValues[6] = static_cast<uint32_t>(seed >> 32);
   }
+  // Only relevant for multinomial
+  MPSGraphTensor *probTensor = nil;
   MPSGraphTensor *resultTensor = nil;
   MPSGraphTensor *stateTensor = nil;
   // used for Normal distributions only
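Pieced together from this hunk, the cached-graph struct now carries the multinomial placeholders next to the Philox state. Roughly, with the omitted members and initializers being assumptions rather than the actual code:

struct RandomCachedGraph : public MPSCachedGraph {
  // stateValues[5] and stateValues[6] hold the low and high halves of the seed;
  // the remaining slots are the counters advanced by updatePhiloxCounters().
  uint32_t stateValues[7];             // actual initializer not shown in this diff
  // Only relevant for multinomial
  MPSGraphTensor *probTensor = nil;    // placeholder for the probability weights
  MPSGraphTensor *resultTensor = nil;  // sampled indices / RNG output
  MPSGraphTensor *stateTensor = nil;   // Philox state fed in on every run
  // used for Normal distributions only
  // ...
};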
@@ -57,10 +59,10 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
   @autoreleasepool {
     string key = op_name + getTensorsStringKey({self}) + ":" + to_string(val1) + ":" + to_string(val2);
-    RandomCachedGraph* cachedGraph = static_cast<RandomCachedGraph *>(cache_->LookUp(key));
+    auto cachedGraph = cache_->LookUpAs<RandomCachedGraph>(key);
     if (!cachedGraph) {
-      cachedGraph = static_cast<RandomCachedGraph *>(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
+      cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^ MPSCachedGraph * () {
         RandomCachedGraph *newCachedGraph = nil;
         @autoreleasepool {
@@ -105,7 +107,7 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
           newCachedGraph->resultTensor = castMPSTensor(mpsGraph, newCachedGraph->resultTensor, self.scalar_type());
         }
         return newCachedGraph;
-      }));
+      });
     }
     // update the Philox state values on each run of the same graph
     cachedGraph->updatePhiloxCounters();
@@ -369,22 +371,30 @@ Tensor& multinomial_with_replacement_mps_kernel(
   auto result_v = inputSize == 1 ? result.view({numDist, n_sample}) : result;
   MPSStream* stream = getCurrentMPSStream();
   uint64_t seed_ = c10::detail::getNonDeterministicRandom(true);
+  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
   @autoreleasepool {
+    string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + to_string(n_sample);
+    auto cachedGraph = cache_->LookUpAs<RandomCachedGraph>(key);
+    if (!cachedGraph) {
+      cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^ MPSCachedGraph * () {
+        RandomCachedGraph *newCachedGraph = nil;
+        @autoreleasepool {
     MPSShape* prob_shape = getMPSShape(self_v);
     MPSGraph* mpsGraph = make_mps_graph();
+          newCachedGraph = new RandomCachedGraph(mpsGraph);
+          newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@7]);
     auto prob_dtype = getMPSDataType(self_v.scalar_type());
     // This is probability weights
-    MPSGraphTensor *probTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self_v.scalar_type()), prob_shape);
+          newCachedGraph->probTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self_v.scalar_type()), prob_shape);
-    MPSGraphTensor *sumProbs = [mpsGraph reductionSumWithTensor:probTensor
+          MPSGraphTensor *sumProbs = [mpsGraph reductionSumWithTensor:newCachedGraph->probTensor
                                                             axis:-1
                                                             name:nil];
-    MPSGraphTensor *normalizedProbs = [mpsGraph divisionWithPrimaryTensor:probTensor
+          MPSGraphTensor *normalizedProbs = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->probTensor
                                                           secondaryTensor:sumProbs
                                                                      name:nil];
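Each row of probTensor is normalized by its own sum before sampling, e.g. weights {2, 3, 5} become probabilities {0.2, 0.3, 0.5}; the reduced axis is kept, so the division broadcasts per distribution. Condensed from the hunk above:

// Sum the weights over the last (category) axis, then divide each weight by that sum.
MPSGraphTensor *sumProbs = [mpsGraph reductionSumWithTensor:newCachedGraph->probTensor
                                                       axis:-1
                                                       name:nil];
MPSGraphTensor *normalizedProbs = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->probTensor
                                                      secondaryTensor:sumProbs
                                                                 name:nil];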
@@ -414,13 +424,11 @@ Tensor& multinomial_with_replacement_mps_kernel(
                                                  withShape:@[ns_numDist, @1, ns_numCategories]
                                                       name:nil];
-    MPSGraphTensor *stateTensor = [mpsGraph randomPhiloxStateTensorWithSeed:seed_
-                                                                       name:nil];
     MPSGraphRandomOpDescriptor *descriptor = [MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform
                                                                                            dataType:prob_dtype];
     NSArray<MPSGraphTensor*> *generatorTensors = [mpsGraph randomTensorWithShape:@[ns_numDist, ns_n_sample, @1]
                                                                        descriptor:descriptor
-                                                                      stateTensor:stateTensor
+                                                                      stateTensor:newCachedGraph->stateTensor
                                                                              name:nil];
     MPSGraphTensor *randomTensor = generatorTensors[0];
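This is the part that makes caching possible: the old code seeded a Philox state tensor inside the graph with randomPhiloxStateTensorWithSeed:, which is fine when the graph is rebuilt on every call but would freeze the random stream once the graph is reused. The cached graph instead reads the state from the stateTensor placeholder, and the host advances and feeds the counters before each run. Schematically (condensed from the surrounding hunks, not a literal excerpt):

// Before (graph rebuilt per call): state baked in at construction time
//   MPSGraphTensor *stateTensor = [mpsGraph randomPhiloxStateTensorWithSeed:seed_ name:nil];
// After (graph cached): state is a 7-element int32 placeholder fed on every run
newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@7]);
NSArray<MPSGraphTensor*> *generatorTensors = [mpsGraph randomTensorWithShape:@[ns_numDist, ns_n_sample, @1]
                                                                   descriptor:descriptor
                                                                  stateTensor:newCachedGraph->stateTensor
                                                                         name:nil];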
@@ -457,20 +465,32 @@ Tensor& multinomial_with_replacement_mps_kernel(
     MPSGraphTensor *reshapeTensor = [mpsGraph reshapeTensor:reducedTensor
                                                    withShape:@[ns_numDist ,ns_n_sample]
                                                         name:nil];
-    MPSGraphTensor *resultTensor = [mpsGraph castTensor:reshapeTensor
+          newCachedGraph->resultTensor = [mpsGraph castTensor:reshapeTensor
                                                  toType:getMPSDataType(result.scalar_type())
                                                    name:@"resultTensor"];
+        }
+        return newCachedGraph;
+      });
+    }
+    // update the Philox state values on each run of the same graph
+    cachedGraph->updatePhiloxCounters();
+    // feed the updated state values to the graph
+    MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@7]];
+    MPSNDArray *stateNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: stateDesc] autorelease];
+    [stateNDArray writeBytes: &cachedGraph->stateValues[0] strideBytes: nil];
+    MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: stateNDArray] autorelease];
-    auto probPlaceholder = Placeholder(probTensor, self_v);
-    auto outputPlaceholder = Placeholder(resultTensor, result_v);
+    auto probPlaceholder = Placeholder(cachedGraph->probTensor, self_v);
+    auto outputPlaceholder = Placeholder(cachedGraph->resultTensor, result_v);
     NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
+      cachedGraph->stateTensor : stateTensorData,
       probPlaceholder.getMPSGraphTensor() : probPlaceholder.getMPSGraphTensorData()
     };
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
       outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
     };
-    runMPSGraph(stream, mpsGraph, feeds, results);
+    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
   }
   return result;