mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-27 09:04:53 +08:00
[MPS] Cache multinomial_with_replacement graph (#86437)
Reuse existing RandomCachedGraph to keep RNG state as part of the graph Add `CreateCachedGraphAs` convenience wrapper Addresses https://github.com/pytorch/pytorch/pull/86342#pullrequestreview-1132197848 Pull Request resolved: https://github.com/pytorch/pytorch/pull/86437 Approved by: https://github.com/kulinseth
This commit is contained in:
committed by
PyTorch MergeBot
parent
9ceadcadb2
commit
10aead9adc
@ -204,6 +204,11 @@ struct MPSGraphCache
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock, void* view_ptr = nullptr) {
|
||||
return static_cast<T *>(CreateCachedGraph(key, createCacheBlock, view_ptr));
|
||||
}
|
||||
|
||||
MPSCachedGraph* LookUp(const std::string& key) const {
|
||||
|
||||
__block MPSCachedGraph* result = nullptr;
|
||||
|
||||
@ -20,6 +20,8 @@ struct RandomCachedGraph : public MPSCachedGraph
|
||||
stateValues[5] = static_cast<uint32_t>(seed);
|
||||
stateValues[6] = static_cast<uint32_t>(seed >> 32);
|
||||
}
|
||||
// Only relevant for multinomial
|
||||
MPSGraphTensor *probTensor = nil;
|
||||
MPSGraphTensor *resultTensor = nil;
|
||||
MPSGraphTensor *stateTensor = nil;
|
||||
// used for Normal distributions only
|
||||
@ -57,10 +59,10 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + getTensorsStringKey({self}) + ":" + to_string(val1) + ":" + to_string(val2);
|
||||
RandomCachedGraph* cachedGraph = static_cast<RandomCachedGraph *>(cache_->LookUp(key));
|
||||
auto cachedGraph = cache_->LookUpAs<RandomCachedGraph>(key);
|
||||
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = static_cast<RandomCachedGraph *>(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^ MPSCachedGraph * () {
|
||||
RandomCachedGraph *newCachedGraph = nil;
|
||||
|
||||
@autoreleasepool {
|
||||
@ -105,7 +107,7 @@ Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2,
|
||||
newCachedGraph->resultTensor = castMPSTensor(mpsGraph, newCachedGraph->resultTensor, self.scalar_type());
|
||||
}
|
||||
return newCachedGraph;
|
||||
}));
|
||||
});
|
||||
}
|
||||
// update the Philox state values on each run of the same graph
|
||||
cachedGraph->updatePhiloxCounters();
|
||||
@ -369,108 +371,126 @@ Tensor& multinomial_with_replacement_mps_kernel(
|
||||
auto result_v = inputSize == 1 ? result.view({numDist, n_sample}) : result;
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
uint64_t seed_ = c10::detail::getNonDeterministicRandom(true);
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
@autoreleasepool {
|
||||
MPSShape* prob_shape = getMPSShape(self_v);
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + to_string(n_sample);
|
||||
auto cachedGraph = cache_->LookUpAs<RandomCachedGraph>(key);
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<RandomCachedGraph>(key, ^ MPSCachedGraph * () {
|
||||
RandomCachedGraph *newCachedGraph = nil;
|
||||
@autoreleasepool {
|
||||
MPSShape* prob_shape = getMPSShape(self_v);
|
||||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new RandomCachedGraph(mpsGraph);
|
||||
newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@7]);
|
||||
|
||||
auto prob_dtype = getMPSDataType(self_v.scalar_type());
|
||||
auto prob_dtype = getMPSDataType(self_v.scalar_type());
|
||||
|
||||
// This is probability weights
|
||||
MPSGraphTensor *probTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self_v.scalar_type()), prob_shape);
|
||||
// This is probability weights
|
||||
newCachedGraph->probTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self_v.scalar_type()), prob_shape);
|
||||
|
||||
MPSGraphTensor *sumProbs = [mpsGraph reductionSumWithTensor:probTensor
|
||||
axis:-1
|
||||
name:nil];
|
||||
MPSGraphTensor *sumProbs = [mpsGraph reductionSumWithTensor:newCachedGraph->probTensor
|
||||
axis:-1
|
||||
name:nil];
|
||||
|
||||
MPSGraphTensor *normalizedProbs = [mpsGraph divisionWithPrimaryTensor:probTensor
|
||||
secondaryTensor:sumProbs
|
||||
name:nil];
|
||||
MPSGraphTensor *normalizedProbs = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->probTensor
|
||||
secondaryTensor:sumProbs
|
||||
name:nil];
|
||||
|
||||
auto ns_numCategories = [NSNumber numberWithInt:numCategories];
|
||||
auto ns_numDist = [NSNumber numberWithInt:numDist];
|
||||
auto ns_n_sample = [NSNumber numberWithInt:n_sample];
|
||||
auto ns_numCategories = [NSNumber numberWithInt:numCategories];
|
||||
auto ns_numDist = [NSNumber numberWithInt:numDist];
|
||||
auto ns_n_sample = [NSNumber numberWithInt:n_sample];
|
||||
|
||||
MPSGraphTensor *ones = [mpsGraph constantWithScalar:1.0f
|
||||
shape:@[ns_numCategories, ns_numCategories]
|
||||
dataType:prob_dtype];
|
||||
MPSGraphTensor *upperTriangle = [mpsGraph bandPartWithTensor:ones
|
||||
numLower:0
|
||||
numUpper:-1
|
||||
MPSGraphTensor *ones = [mpsGraph constantWithScalar:1.0f
|
||||
shape:@[ns_numCategories, ns_numCategories]
|
||||
dataType:prob_dtype];
|
||||
MPSGraphTensor *upperTriangle = [mpsGraph bandPartWithTensor:ones
|
||||
numLower:0
|
||||
numUpper:-1
|
||||
name:nil];
|
||||
MPSGraphTensor *upperProbRange = [mpsGraph matrixMultiplicationWithPrimaryTensor:normalizedProbs
|
||||
secondaryTensor:upperTriangle
|
||||
name:nil];
|
||||
|
||||
MPSGraphTensor *lowerProbRange = [mpsGraph subtractionWithPrimaryTensor:upperProbRange
|
||||
secondaryTensor:normalizedProbs
|
||||
name:nil];
|
||||
|
||||
upperProbRange = [mpsGraph reshapeTensor:upperProbRange
|
||||
withShape:@[ns_numDist, @1, ns_numCategories]
|
||||
name:nil];
|
||||
lowerProbRange = [mpsGraph reshapeTensor:lowerProbRange
|
||||
withShape:@[ns_numDist, @1, ns_numCategories]
|
||||
name:nil];
|
||||
|
||||
MPSGraphRandomOpDescriptor *descriptor = [MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform
|
||||
dataType:prob_dtype];
|
||||
NSArray<MPSGraphTensor*> *generatorTensors = [mpsGraph randomTensorWithShape:@[ns_numDist, ns_n_sample, @1]
|
||||
descriptor:descriptor
|
||||
stateTensor:newCachedGraph->stateTensor
|
||||
name:nil];
|
||||
MPSGraphTensor *randomTensor = generatorTensors[0];
|
||||
|
||||
auto broadcastShape = @[ns_numDist ,ns_n_sample, ns_numCategories];
|
||||
int broadcastShapeVals[3] = {numDist, n_sample, numCategories};
|
||||
MPSGraphTensor *broadcastShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count]
|
||||
shape:@[[NSNumber numberWithUnsignedInteger:broadcastShape.count]]
|
||||
dataType:MPSDataTypeUInt32];
|
||||
|
||||
MPSGraphTensor *samplesTensor = [mpsGraph broadcastTensor:randomTensor
|
||||
toShape:broadcastShape
|
||||
name:nil];
|
||||
MPSGraphTensor *sampleAbove = [mpsGraph greaterThanWithPrimaryTensor:samplesTensor
|
||||
secondaryTensor:lowerProbRange
|
||||
name:nil];
|
||||
MPSGraphTensor *sampleBelow = [mpsGraph lessThanWithPrimaryTensor:samplesTensor
|
||||
secondaryTensor:upperProbRange
|
||||
name:nil];
|
||||
MPSGraphTensor *sampleWithin = [mpsGraph logicalANDWithPrimaryTensor:sampleAbove
|
||||
secondaryTensor:sampleBelow
|
||||
name:nil];
|
||||
MPSGraphTensor *sampleMask = [mpsGraph castTensor:sampleWithin
|
||||
toType:MPSDataTypeInt32
|
||||
name:@"sampleMask"];
|
||||
MPSGraphTensor *categoriesTensor = [mpsGraph coordinateAlongAxis:-1
|
||||
withShapeTensor:broadcastShapeTensor
|
||||
name:nil];
|
||||
MPSGraphTensor *binnedSamplesTensor = [mpsGraph multiplicationWithPrimaryTensor:categoriesTensor
|
||||
secondaryTensor:sampleMask
|
||||
name:nil];
|
||||
MPSGraphTensor *reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor
|
||||
axis:-1
|
||||
name:nil];
|
||||
MPSGraphTensor *reshapeTensor = [mpsGraph reshapeTensor:reducedTensor
|
||||
withShape:@[ns_numDist ,ns_n_sample]
|
||||
name:nil];
|
||||
MPSGraphTensor *upperProbRange = [mpsGraph matrixMultiplicationWithPrimaryTensor:normalizedProbs
|
||||
secondaryTensor:upperTriangle
|
||||
name:nil];
|
||||
newCachedGraph->resultTensor = [mpsGraph castTensor:reshapeTensor
|
||||
toType:getMPSDataType(result.scalar_type())
|
||||
name:@"resultTensor"];
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
}
|
||||
// update the Philox state values on each run of the same graph
|
||||
cachedGraph->updatePhiloxCounters();
|
||||
// feed the updated state values to the graph
|
||||
MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@7]];
|
||||
MPSNDArray *stateNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: stateDesc] autorelease];
|
||||
[stateNDArray writeBytes: &cachedGraph->stateValues[0] strideBytes: nil];
|
||||
MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: stateNDArray] autorelease];
|
||||
|
||||
MPSGraphTensor *lowerProbRange = [mpsGraph subtractionWithPrimaryTensor:upperProbRange
|
||||
secondaryTensor:normalizedProbs
|
||||
name:nil];
|
||||
|
||||
upperProbRange = [mpsGraph reshapeTensor:upperProbRange
|
||||
withShape:@[ns_numDist, @1, ns_numCategories]
|
||||
name:nil];
|
||||
lowerProbRange = [mpsGraph reshapeTensor:lowerProbRange
|
||||
withShape:@[ns_numDist, @1, ns_numCategories]
|
||||
name:nil];
|
||||
|
||||
MPSGraphTensor *stateTensor = [mpsGraph randomPhiloxStateTensorWithSeed:seed_
|
||||
name:nil];
|
||||
MPSGraphRandomOpDescriptor *descriptor = [MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform
|
||||
dataType:prob_dtype];
|
||||
NSArray<MPSGraphTensor*> *generatorTensors = [mpsGraph randomTensorWithShape:@[ns_numDist, ns_n_sample, @1]
|
||||
descriptor:descriptor
|
||||
stateTensor:stateTensor
|
||||
name:nil];
|
||||
MPSGraphTensor *randomTensor = generatorTensors[0];
|
||||
|
||||
auto broadcastShape = @[ns_numDist ,ns_n_sample, ns_numCategories];
|
||||
int broadcastShapeVals[3] = {numDist, n_sample, numCategories};
|
||||
MPSGraphTensor *broadcastShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count]
|
||||
shape:@[[NSNumber numberWithUnsignedInteger:broadcastShape.count]]
|
||||
dataType:MPSDataTypeUInt32];
|
||||
|
||||
MPSGraphTensor *samplesTensor = [mpsGraph broadcastTensor:randomTensor
|
||||
toShape:broadcastShape
|
||||
name:nil];
|
||||
MPSGraphTensor *sampleAbove = [mpsGraph greaterThanWithPrimaryTensor:samplesTensor
|
||||
secondaryTensor:lowerProbRange
|
||||
name:nil];
|
||||
MPSGraphTensor *sampleBelow = [mpsGraph lessThanWithPrimaryTensor:samplesTensor
|
||||
secondaryTensor:upperProbRange
|
||||
name:nil];
|
||||
MPSGraphTensor *sampleWithin = [mpsGraph logicalANDWithPrimaryTensor:sampleAbove
|
||||
secondaryTensor:sampleBelow
|
||||
name:nil];
|
||||
MPSGraphTensor *sampleMask = [mpsGraph castTensor:sampleWithin
|
||||
toType:MPSDataTypeInt32
|
||||
name:@"sampleMask"];
|
||||
MPSGraphTensor *categoriesTensor = [mpsGraph coordinateAlongAxis:-1
|
||||
withShapeTensor:broadcastShapeTensor
|
||||
name:nil];
|
||||
MPSGraphTensor *binnedSamplesTensor = [mpsGraph multiplicationWithPrimaryTensor:categoriesTensor
|
||||
secondaryTensor:sampleMask
|
||||
name:nil];
|
||||
MPSGraphTensor *reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor
|
||||
axis:-1
|
||||
name:nil];
|
||||
MPSGraphTensor *reshapeTensor = [mpsGraph reshapeTensor:reducedTensor
|
||||
withShape:@[ns_numDist ,ns_n_sample]
|
||||
name:nil];
|
||||
MPSGraphTensor *resultTensor = [mpsGraph castTensor:reshapeTensor
|
||||
toType:getMPSDataType(result.scalar_type())
|
||||
name:@"resultTensor"];
|
||||
|
||||
auto probPlaceholder = Placeholder(probTensor, self_v);
|
||||
auto outputPlaceholder = Placeholder(resultTensor, result_v);
|
||||
auto probPlaceholder = Placeholder(cachedGraph->probTensor, self_v);
|
||||
auto outputPlaceholder = Placeholder(cachedGraph->resultTensor, result_v);
|
||||
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
|
||||
cachedGraph->stateTensor : stateTensorData,
|
||||
probPlaceholder.getMPSGraphTensor() : probPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
|
||||
runMPSGraph(stream, mpsGraph, feeds, results);
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
Reference in New Issue
Block a user