[MPS] Cache multinomial_with_replacement graph (#86437)

Reuse existing RandomCachedGraph to keep RNG state as part of the graph
Add `CreateCachedGraphAs` convenience wrapper
Addresses https://github.com/pytorch/pytorch/pull/86342#pullrequestreview-1132197848
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86437
Approved by: https://github.com/kulinseth
This commit is contained in:
Nikita Shulga
2022-10-07 04:39:28 +00:00
committed by PyTorch MergeBot
parent 9ceadcadb2
commit 10aead9adc
2 changed files with 113 additions and 88 deletions

View File

@ -204,6 +204,11 @@ struct MPSGraphCache
return result;
}
// Typed convenience wrapper around CreateCachedGraph(): forwards `key`,
// `createCacheBlock` and `view_ptr` unchanged, then returns the resulting
// graph pointer converted to `T*`.
//
// The conversion is a static_cast, so no runtime type check is performed;
// the caller is responsible for `T` matching the concrete type that
// `createCacheBlock` actually produces.
template<typename T>
inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock, void* view_ptr = nullptr) {
  auto untypedGraph = CreateCachedGraph(key, createCacheBlock, view_ptr);
  return static_cast<T *>(untypedGraph);
}
MPSCachedGraph* LookUp(const std::string& key) const {
__block MPSCachedGraph* result = nullptr;

View File

@ -20,6 +20,8 @@ struct RandomCachedGraph : public MPSCachedGraph
stateValues[5] = static_cast<uint32_t>(seed);
stateValues[6] = static_cast<uint32_t>(seed >> 32);
}
// Only relevant for multinomial
MPSGraphTensor *probTensor = nil;
MPSGraphTensor *resultTensor = nil;
MPSGraphTensor *stateTensor = nil;
// used for Normal distributions only
@ -57,10 +59,10 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
@autoreleasepool {
string key = op_name + getTensorsStringKey({self}) + ":" + to_string(val1) + ":" + to_string(val2);
RandomCachedGraph* cachedGraph = static_cast<RandomCachedGraph *>(cache_->LookUp(key));
auto cachedGraph = cache_->LookUpAs<RandomCachedGraph>(key);
if (!cachedGraph) {
cachedGraph = static_cast<RandomCachedGraph *>(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^ MPSCachedGraph * () {
RandomCachedGraph *newCachedGraph = nil;
@autoreleasepool {
@ -105,7 +107,7 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
newCachedGraph->resultTensor = castMPSTensor(mpsGraph, newCachedGraph->resultTensor, self.scalar_type());
}
return newCachedGraph;
}));
});
}
// update the Philox state values on each run of the same graph
cachedGraph->updatePhiloxCounters();
@ -369,108 +371,126 @@ Tensor& multinomial_with_replacement_mps_kernel(
auto result_v = inputSize == 1 ? result.view({numDist, n_sample}) : result;
MPSStream* stream = getCurrentMPSStream();
uint64_t seed_ = c10::detail::getNonDeterministicRandom(true);
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
MPSShape* prob_shape = getMPSShape(self_v);
MPSGraph* mpsGraph = make_mps_graph();
string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + to_string(n_sample);
auto cachedGraph = cache_->LookUpAs<RandomCachedGraph>(key);
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^ MPSCachedGraph * () {
RandomCachedGraph *newCachedGraph = nil;
@autoreleasepool {
MPSShape* prob_shape = getMPSShape(self_v);
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new RandomCachedGraph(mpsGraph);
newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@7]);
auto prob_dtype = getMPSDataType(self_v.scalar_type());
auto prob_dtype = getMPSDataType(self_v.scalar_type());
// These are the probability weights
MPSGraphTensor *probTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self_v.scalar_type()), prob_shape);
// These are the probability weights
newCachedGraph->probTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self_v.scalar_type()), prob_shape);
MPSGraphTensor *sumProbs = [mpsGraph reductionSumWithTensor:probTensor
axis:-1
name:nil];
MPSGraphTensor *sumProbs = [mpsGraph reductionSumWithTensor:newCachedGraph->probTensor
axis:-1
name:nil];
MPSGraphTensor *normalizedProbs = [mpsGraph divisionWithPrimaryTensor:probTensor
secondaryTensor:sumProbs
name:nil];
MPSGraphTensor *normalizedProbs = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->probTensor
secondaryTensor:sumProbs
name:nil];
auto ns_numCategories = [NSNumber numberWithInt:numCategories];
auto ns_numDist = [NSNumber numberWithInt:numDist];
auto ns_n_sample = [NSNumber numberWithInt:n_sample];
auto ns_numCategories = [NSNumber numberWithInt:numCategories];
auto ns_numDist = [NSNumber numberWithInt:numDist];
auto ns_n_sample = [NSNumber numberWithInt:n_sample];
MPSGraphTensor *ones = [mpsGraph constantWithScalar:1.0f
shape:@[ns_numCategories, ns_numCategories]
dataType:prob_dtype];
MPSGraphTensor *upperTriangle = [mpsGraph bandPartWithTensor:ones
numLower:0
numUpper:-1
MPSGraphTensor *ones = [mpsGraph constantWithScalar:1.0f
shape:@[ns_numCategories, ns_numCategories]
dataType:prob_dtype];
MPSGraphTensor *upperTriangle = [mpsGraph bandPartWithTensor:ones
numLower:0
numUpper:-1
name:nil];
MPSGraphTensor *upperProbRange = [mpsGraph matrixMultiplicationWithPrimaryTensor:normalizedProbs
secondaryTensor:upperTriangle
name:nil];
MPSGraphTensor *lowerProbRange = [mpsGraph subtractionWithPrimaryTensor:upperProbRange
secondaryTensor:normalizedProbs
name:nil];
upperProbRange = [mpsGraph reshapeTensor:upperProbRange
withShape:@[ns_numDist, @1, ns_numCategories]
name:nil];
lowerProbRange = [mpsGraph reshapeTensor:lowerProbRange
withShape:@[ns_numDist, @1, ns_numCategories]
name:nil];
MPSGraphRandomOpDescriptor *descriptor = [MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform
dataType:prob_dtype];
NSArray<MPSGraphTensor*> *generatorTensors = [mpsGraph randomTensorWithShape:@[ns_numDist, ns_n_sample, @1]
descriptor:descriptor
stateTensor:newCachedGraph->stateTensor
name:nil];
MPSGraphTensor *randomTensor = generatorTensors[0];
auto broadcastShape = @[ns_numDist ,ns_n_sample, ns_numCategories];
int broadcastShapeVals[3] = {numDist, n_sample, numCategories};
MPSGraphTensor *broadcastShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count]
shape:@[[NSNumber numberWithUnsignedInteger:broadcastShape.count]]
dataType:MPSDataTypeUInt32];
MPSGraphTensor *samplesTensor = [mpsGraph broadcastTensor:randomTensor
toShape:broadcastShape
name:nil];
MPSGraphTensor *sampleAbove = [mpsGraph greaterThanWithPrimaryTensor:samplesTensor
secondaryTensor:lowerProbRange
name:nil];
MPSGraphTensor *sampleBelow = [mpsGraph lessThanWithPrimaryTensor:samplesTensor
secondaryTensor:upperProbRange
name:nil];
MPSGraphTensor *sampleWithin = [mpsGraph logicalANDWithPrimaryTensor:sampleAbove
secondaryTensor:sampleBelow
name:nil];
MPSGraphTensor *sampleMask = [mpsGraph castTensor:sampleWithin
toType:MPSDataTypeInt32
name:@"sampleMask"];
MPSGraphTensor *categoriesTensor = [mpsGraph coordinateAlongAxis:-1
withShapeTensor:broadcastShapeTensor
name:nil];
MPSGraphTensor *binnedSamplesTensor = [mpsGraph multiplicationWithPrimaryTensor:categoriesTensor
secondaryTensor:sampleMask
name:nil];
MPSGraphTensor *reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor
axis:-1
name:nil];
MPSGraphTensor *reshapeTensor = [mpsGraph reshapeTensor:reducedTensor
withShape:@[ns_numDist ,ns_n_sample]
name:nil];
MPSGraphTensor *upperProbRange = [mpsGraph matrixMultiplicationWithPrimaryTensor:normalizedProbs
secondaryTensor:upperTriangle
name:nil];
newCachedGraph->resultTensor = [mpsGraph castTensor:reshapeTensor
toType:getMPSDataType(result.scalar_type())
name:@"resultTensor"];
}
return newCachedGraph;
});
}
// update the Philox state values on each run of the same graph
cachedGraph->updatePhiloxCounters();
// feed the updated state values to the graph
MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@7]];
MPSNDArray *stateNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: stateDesc] autorelease];
[stateNDArray writeBytes: &cachedGraph->stateValues[0] strideBytes: nil];
MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: stateNDArray] autorelease];
MPSGraphTensor *lowerProbRange = [mpsGraph subtractionWithPrimaryTensor:upperProbRange
secondaryTensor:normalizedProbs
name:nil];
upperProbRange = [mpsGraph reshapeTensor:upperProbRange
withShape:@[ns_numDist, @1, ns_numCategories]
name:nil];
lowerProbRange = [mpsGraph reshapeTensor:lowerProbRange
withShape:@[ns_numDist, @1, ns_numCategories]
name:nil];
MPSGraphTensor *stateTensor = [mpsGraph randomPhiloxStateTensorWithSeed:seed_
name:nil];
MPSGraphRandomOpDescriptor *descriptor = [MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform
dataType:prob_dtype];
NSArray<MPSGraphTensor*> *generatorTensors = [mpsGraph randomTensorWithShape:@[ns_numDist, ns_n_sample, @1]
descriptor:descriptor
stateTensor:stateTensor
name:nil];
MPSGraphTensor *randomTensor = generatorTensors[0];
auto broadcastShape = @[ns_numDist ,ns_n_sample, ns_numCategories];
int broadcastShapeVals[3] = {numDist, n_sample, numCategories};
MPSGraphTensor *broadcastShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count]
shape:@[[NSNumber numberWithUnsignedInteger:broadcastShape.count]]
dataType:MPSDataTypeUInt32];
MPSGraphTensor *samplesTensor = [mpsGraph broadcastTensor:randomTensor
toShape:broadcastShape
name:nil];
MPSGraphTensor *sampleAbove = [mpsGraph greaterThanWithPrimaryTensor:samplesTensor
secondaryTensor:lowerProbRange
name:nil];
MPSGraphTensor *sampleBelow = [mpsGraph lessThanWithPrimaryTensor:samplesTensor
secondaryTensor:upperProbRange
name:nil];
MPSGraphTensor *sampleWithin = [mpsGraph logicalANDWithPrimaryTensor:sampleAbove
secondaryTensor:sampleBelow
name:nil];
MPSGraphTensor *sampleMask = [mpsGraph castTensor:sampleWithin
toType:MPSDataTypeInt32
name:@"sampleMask"];
MPSGraphTensor *categoriesTensor = [mpsGraph coordinateAlongAxis:-1
withShapeTensor:broadcastShapeTensor
name:nil];
MPSGraphTensor *binnedSamplesTensor = [mpsGraph multiplicationWithPrimaryTensor:categoriesTensor
secondaryTensor:sampleMask
name:nil];
MPSGraphTensor *reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor
axis:-1
name:nil];
MPSGraphTensor *reshapeTensor = [mpsGraph reshapeTensor:reducedTensor
withShape:@[ns_numDist ,ns_n_sample]
name:nil];
MPSGraphTensor *resultTensor = [mpsGraph castTensor:reshapeTensor
toType:getMPSDataType(result.scalar_type())
name:@"resultTensor"];
auto probPlaceholder = Placeholder(probTensor, self_v);
auto outputPlaceholder = Placeholder(resultTensor, result_v);
auto probPlaceholder = Placeholder(cachedGraph->probTensor, self_v);
auto outputPlaceholder = Placeholder(cachedGraph->resultTensor, result_v);
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
cachedGraph->stateTensor : stateTensorData,
probPlaceholder.getMPSGraphTensor() : probPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
runMPSGraph(stream, mpsGraph, feeds, results);
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
return result;