Update on "[WIP] Add a simple cache mechanism to accelerate torch.compile-for-eager"

This PR is a follow-up to RFC https://github.com/pytorch/pytorch/issues/115545.

In this PR, we add a cache mechanism to accelerate torch.compile-for-eager.
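A rough sketch of the usage pattern this is meant to speed up (only the public torch.compile API is used here; nothing below is the PR's own interface, and the caching behaviour is an assumption about the intended effect, not a demonstration of the new code):

```python
import torch

def add_relu(x, y):
    return torch.relu(x + y)

# The first call pays the torch.compile cost; subsequent calls with the same
# input signature are expected to hit the cache and reuse the compiled kernel
# instead of recompiling.
compiled_add_relu = torch.compile(add_relu)

x, y = torch.randn(8, 8), torch.randn(8, 8)
out_first = compiled_add_relu(x, y)   # compile + populate cache
out_second = compiled_add_relu(x, y)  # cache hit, no recompilation
```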




cc voznesenskym penguinwu jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
Author: Wang, Eikan
Date: 2024-03-12 14:56:22 +00:00
21 changed files with 614 additions and 1841 deletions

View File

@ -205,6 +205,7 @@ class Context final {
class UniformParamsBuffer final {
private:
Context* context_p_;
size_t nbytes_;
VulkanBuffer vulkan_buffer_;
public:
@ -213,6 +214,7 @@ class UniformParamsBuffer final {
template <typename Block>
UniformParamsBuffer(Context* context_p, const Block& block)
: context_p_(context_p),
nbytes_(sizeof(block)),
vulkan_buffer_(
context_p_->adapter_ptr()->vma().create_params_buffer(block)) {}
@ -231,6 +233,21 @@ class UniformParamsBuffer final {
VulkanBuffer& buffer() {
return vulkan_buffer_;
}
template <typename Block>
void update(const Block& block) {
if (sizeof(block) != nbytes_) {
VK_THROW(
"Attempted to update UniformParamsBuffer with data of different size");
}
// Fill the uniform buffer with data in block
{
MemoryMap mapping(vulkan_buffer_, MemoryAccessType::WRITE);
Block* data_ptr = mapping.template data<Block>();
*data_ptr = block;
}
}
};
class StorageBuffer final {
@ -238,6 +255,7 @@ class StorageBuffer final {
Context* context_p_;
ScalarType dtype_;
size_t numel_;
size_t nbytes_;
VulkanBuffer vulkan_buffer_;
public:
@ -249,8 +267,9 @@ class StorageBuffer final {
: context_p_(context_p),
dtype_(dtype),
numel_(numel),
nbytes_(element_size(dtype_) * numel_),
vulkan_buffer_(context_p_->adapter_ptr()->vma().create_storage_buffer(
element_size(dtype_) * numel_,
nbytes_,
gpuonly)) {}
StorageBuffer(const StorageBuffer&) = delete;
@ -270,6 +289,14 @@ class StorageBuffer final {
inline VulkanBuffer& buffer() {
return vulkan_buffer_;
}
inline size_t numel() {
return numel_;
}
inline size_t nbytes() {
return nbytes_;
}
};
bool available();

View File

@ -151,6 +151,10 @@ class VulkanBuffer final {
return (memory_.allocation != VK_NULL_HANDLE);
}
inline bool owns_memory() const {
return owns_memory_;
}
operator bool() const {
return (handle_ != VK_NULL_HANDLE);
}
@ -372,6 +376,10 @@ class VulkanImage final {
return (memory_.allocation != VK_NULL_HANDLE);
}
inline bool owns_memory() const {
return owns_memory_;
}
inline operator bool() const {
return (handles_.image != VK_NULL_HANDLE);
}

View File

@ -12,6 +12,9 @@
#define VK_KERNEL(shader_name) \
::at::native::vulkan::api::shader_registry().get_shader_info(#shader_name)
#define VK_KERNEL_FROM_STR(shader_name_str) \
::at::native::vulkan::api::shader_registry().get_shader_info(shader_name_str)
namespace at {
namespace native {
namespace vulkan {

View File

@ -318,8 +318,8 @@ api::UniformParamsBuffer make_metadata_uniform(
}
vTensor::BufferMetadata metadata{
api::utils::make_nchw_uvec4(sizes),
api::utils::make_nchw_uvec4(strides),
api::utils::make_whcn_uvec4(sizes),
api::utils::make_whcn_uvec4(strides),
api::utils::safe_downcast<uint32_t>(sizes.size()),
api::utils::safe_downcast<uint32_t>(api::utils::multiply_integers(sizes)),
};
@ -347,12 +347,13 @@ vTensor::vTensor(
strides_{calc_strides(sizes, memory_layout_, storage_type)},
gpu_sizes_{calc_gpu_sizes(sizes, memory_layout_, storage_type)},
gpu_strides_{calc_strides(gpu_sizes_, memory_layout_, storage_type)},
// Vulkan uniform buffer containing sizes and stride info
metadata_uniform_{make_metadata_uniform(
context,
gpu_sizes_,
gpu_strides_,
storage_type)},
virtual_extents_(
create_image_extents(gpu_sizes_, storage_type, memory_layout)),
// Utility Uniform Buffers that can be passed to shaders as arguments
metadata_uniform_(),
cpu_sizes_uniform_(nullptr),
gpu_sizes_uniform_(nullptr),
extents_uniform_(nullptr),
// Construct Tensor storage
view_(std::make_shared<vTensorStorage>(
context,
@ -377,12 +378,13 @@ vTensor::vTensor(
strides_{calc_strides(sizes, memory_layout_, storage_type)},
gpu_sizes_{calc_gpu_sizes(sizes, memory_layout_, storage_type)},
gpu_strides_{calc_strides(gpu_sizes_, memory_layout_, storage_type)},
virtual_extents_(
create_image_extents(gpu_sizes_, storage_type, memory_layout)),
// Vulkan uniform buffer containing sizes and stride info
metadata_uniform_{make_metadata_uniform(
context,
gpu_sizes_,
gpu_strides_,
storage_type)},
metadata_uniform_(),
cpu_sizes_uniform_(nullptr),
gpu_sizes_uniform_(nullptr),
extents_uniform_(nullptr),
// Quantization params
is_quantized_{true},
q_scale_{q_scale},
@ -425,10 +427,47 @@ api::VulkanBuffer& vTensor::buffer(
return view_->buffer_;
}
api::VulkanBuffer& vTensor::buffer_metadata() {
if (!metadata_uniform_.buffer()) {
metadata_uniform_ = make_metadata_uniform(
view_->context_, gpu_sizes_, gpu_strides_, storage_type());
}
return metadata_uniform_.buffer();
}
std::shared_ptr<api::UniformParamsBuffer> vTensor::cpu_sizes_ubo() {
if (!cpu_sizes_uniform_) {
cpu_sizes_uniform_.reset(new api::UniformParamsBuffer(
view_->context_, api::utils::make_whcn_ivec4(sizes_)));
}
return cpu_sizes_uniform_;
}
std::shared_ptr<api::UniformParamsBuffer> vTensor::gpu_sizes_ubo() {
if (!gpu_sizes_uniform_) {
gpu_sizes_uniform_.reset(new api::UniformParamsBuffer(
view_->context_, api::utils::make_whcn_ivec4(gpu_sizes_)));
}
return gpu_sizes_uniform_;
}
std::shared_ptr<api::UniformParamsBuffer> vTensor::extents_ubo() {
if (!extents_uniform_) {
extents_uniform_.reset(new api::UniformParamsBuffer(
view_->context_,
api::utils::uvec4(
{view_->extents_.data[0],
view_->extents_.data[1],
view_->extents_.data[2],
1u})));
}
return extents_uniform_;
}
vTensor::BufferMetadata vTensor::get_cpu_buffer_metadata() const {
return {
api::utils::make_nchw_uvec4(sizes_),
api::utils::make_nchw_uvec4(strides_),
api::utils::make_whcn_uvec4(sizes_),
api::utils::make_whcn_uvec4(strides_),
api::utils::safe_downcast<uint32_t>(sizes_.size()),
api::utils::safe_downcast<uint32_t>(
api::utils::multiply_integers(sizes_)),
@ -473,6 +512,65 @@ void vTensor::bind_allocation(const api::MemoryAllocation& allocation) {
}
}
void vTensor::update_size_metadata(const std::vector<int64_t>& new_sizes) {
sizes_ = new_sizes;
gpu_sizes_ = calc_gpu_sizes(sizes_, memory_layout_, storage_type());
virtual_extents_ =
create_image_extents(gpu_sizes_, storage_type(), memory_layout_);
if (cpu_sizes_uniform_) {
cpu_sizes_uniform_->update(api::utils::make_whcn_ivec4(sizes_));
}
if (gpu_sizes_uniform_) {
gpu_sizes_uniform_->update(api::utils::make_whcn_ivec4(gpu_sizes_));
}
if (extents_uniform_) {
extents_uniform_->update(api::utils::uvec4(
{virtual_extents_.data[0],
virtual_extents_.data[1],
virtual_extents_.data[2],
1u}));
}
}
void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
update_size_metadata(new_sizes);
view_->discard_and_reallocate(
calc_gpu_sizes(new_sizes, memory_layout_, storage_type()),
memory_layout_,
dtype_);
}
void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
update_size_metadata(new_sizes);
if (storage_type() == api::StorageType::BUFFER) {
if (gpu_nbytes() > view_->buffer_.mem_size()) {
VK_THROW(
"Cannot virtual_resize a vTensor with sizes that require a larger "
"buffer! reallocate() should be used instead.");
}
} else {
bool valid_resize = true;
if (virtual_extents_.data[0] > view_->extents_.data[0]) {
valid_resize = false;
}
if (virtual_extents_.data[1] > view_->extents_.data[1]) {
valid_resize = false;
}
if (virtual_extents_.data[2] > view_->extents_.data[2]) {
valid_resize = false;
}
if (!valid_resize) {
VK_THROW(
"Cannot virtual_resize a vTensor with sizes that require a larger "
"image texture! reallocate() should be used instead.");
}
}
}
//
// vTensorStorage
//
@ -569,11 +667,16 @@ vTensorStorage::vTensorStorage(
last_access_{} {}
vTensorStorage::~vTensorStorage() {
flush();
}
void vTensorStorage::flush() {
if (image_) {
context_->register_image_cleanup(image_);
} else if (buffer_) {
context_->register_buffer_cleanup(buffer_);
}
last_access_ = {};
}
void vTensorStorage::transition(
@ -663,6 +766,28 @@ void add_buffer_barrier(
}
}
void vTensorStorage::discard_and_reallocate(
const std::vector<int64_t>& gpu_sizes,
const api::GPUMemoryLayout gpu_memory_layout,
const api::ScalarType dtype) {
const bool image_owns_memory = image_.owns_memory();
const bool buffer_owns_memory = buffer_.owns_memory();
flush();
extents_ = create_image_extents(gpu_sizes, storage_type_, gpu_memory_layout);
image_ = allocate_image(
context_,
extents_,
storage_type_,
api::to_vkformat(dtype),
image_owns_memory);
buffer_length_ = api::utils::multiply_integers(gpu_sizes);
buffer_ = allocate_buffer(
context_, buffer_length_, storage_type_, dtype, buffer_owns_memory);
}
} // namespace vulkan
} // namespace native
} // namespace at

View File

@ -66,6 +66,9 @@ class vTensorStorage final {
LastAccess last_access_;
private:
// Registers underlying memory for cleanup
void flush();
// Memory barrier insertion
void transition(
api::PipelineBarrier&,
@ -79,6 +82,11 @@ class vTensorStorage final {
inline VkFormat texture_format() {
return image_.format();
}
void discard_and_reallocate(
const std::vector<int64_t>& gpu_sizes,
const api::GPUMemoryLayout gpu_memory_layout,
const api::ScalarType dtype);
};
class vTensor final {
@ -141,10 +149,29 @@ class vTensor final {
std::vector<int64_t> gpu_sizes_;
std::vector<int64_t> gpu_strides_;
// The extents that correspond to the tensor's size metadata. Note that this
// may not be the same as the extents of the underlying image texture because
// vTensor can be virtually resized via virtual_resize() which will cause it
// to be interpreted as a tensor with a different size.
api::utils::uvec3 virtual_extents_;
// A Vulkan uniform buffer containing sizes and strides of the GPU buffer that
// can be passed into a shader.
api::UniformParamsBuffer metadata_uniform_;
// A Vulkan uniform buffer containing the tensor sizes that can be passed into
// a shader.
std::shared_ptr<api::UniformParamsBuffer> cpu_sizes_uniform_;
// A Vulkan uniform buffer containing the GPU tensor sizes that can be passed
// into a shader. GPU sizes refers to the sizes of the tensor after padding
// has been applied to one dimension to align it to the next multiple of 4.
std::shared_ptr<api::UniformParamsBuffer> gpu_sizes_uniform_;
// A Vulkan uniform buffer containing the image extents of the underlying
// image texture that can be passed into a shader.
std::shared_ptr<api::UniformParamsBuffer> extents_uniform_;
// Quantization params
bool is_quantized_{false};
double q_scale_{1.0f};
@ -250,13 +277,36 @@ class vTensor final {
return gpu_strides_;
}
inline const api::utils::uvec3& virtual_extents() const {
return virtual_extents_;
}
/*
* Get a uniform buffer containing sizes and strides information of the GPU
* buffer
*/
inline api::VulkanBuffer& buffer_metadata() {
return metadata_uniform_.buffer();
}
api::VulkanBuffer& buffer_metadata();
/*
* Get a uniform buffer object containing the tensor sizes to use in a compute
* shader. Note that the UBO will be created the first time this function is
* called.
*/
std::shared_ptr<api::UniformParamsBuffer> cpu_sizes_ubo();
/*
* Get a uniform buffer object containing the tensor GPU sizes to use in a
* compute shader. Note that the UBO will be created the first time this
* function is called.
*/
std::shared_ptr<api::UniformParamsBuffer> gpu_sizes_ubo();
/*
* Get a uniform buffer object containing the image extents to use in a
* compute shader. Note that the UBO will be created the first time this
* function is called.
*/
std::shared_ptr<api::UniformParamsBuffer> extents_ubo();
/*
* Constructs a BufferMetdata struct based on the original sizes and strides
@ -308,7 +358,7 @@ class vTensor final {
* Returns numel but based on gpu_sizes_ instead of sizes_
*/
inline size_t gpu_numel() const {
return view_->buffer_length_;
return api::utils::multiply_integers(gpu_sizes_);
}
/*
@ -332,6 +382,27 @@ class vTensor final {
* Binds the underlying resource to the given memory allocation
*/
void bind_allocation(const api::MemoryAllocation& allocation);
private:
/*
* Update the size metadata of the vTensor to be new sizes. Should not be used
* directly, reallocate() or virtual_resize() should be used instead.
*/
void update_size_metadata(const std::vector<int64_t>& new_sizes);
public:
/*
* Discard the underlying VkImage or VkBuffer and re-allocate based on new
* tensor sizes
*/
void reallocate(const std::vector<int64_t>& new_sizes);
/*
* Perform a virtual resize of the vTensor by modifying the size metadata that
* gets used in compute shaders. This allows the shader to treat the
* underlying resource as if it were a different size.
*/
void virtual_resize(const std::vector<int64_t>& new_sizes);
};
void add_buffer_barrier(

View File

@ -328,10 +328,10 @@ inline ivec3 make_ivec3(uvec3 ints) {
}
/*
* Given a vector of up to 4 int64_t representing the sizes of a tensor,
* Given a vector of up to 4 uint64_t representing the sizes of a tensor,
* constructs a uvec4 containing those elements in reverse order.
*/
inline uvec4 make_nchw_uvec4(const std::vector<int64_t>& arr) {
inline uvec4 make_whcn_uvec4(const std::vector<int64_t>& arr) {
uint32_t w = safe_downcast<uint32_t>(val_at(-1, arr));
uint32_t h = safe_downcast<uint32_t>(val_at(-2, arr));
uint32_t c = safe_downcast<uint32_t>(val_at(-3, arr));
@ -340,6 +340,19 @@ inline uvec4 make_nchw_uvec4(const std::vector<int64_t>& arr) {
return {w, h, c, n};
}
/*
* Given a vector of up to 4 int64_t representing the sizes of a tensor,
* constructs an ivec4 containing those elements in reverse order.
*/
inline ivec4 make_whcn_ivec4(const std::vector<int64_t>& arr) {
int32_t w = val_at(-1, arr);
int32_t h = val_at(-2, arr);
int32_t c = val_at(-3, arr);
int32_t n = val_at(-4, arr);
return {w, h, c, n};
}
/*
* Wrapper around std::accumulate that accumulates values of a container of
* integral types into int64_t. Taken from `multiply_integers` in

View File

@ -331,7 +331,7 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
a = torch.tensor([[2, 4, 0], [8, 0, 12]]).to(self.rank)
self.assertEqual(tensor_list[0], a)
except RuntimeError as e:
if "allreduce_sparse is only available in the NCCL experimental branch." in str(e):
if "NCCL does not support all_reduce with sparse tensors" in str(e):
pass
else:
# Rethrow the exception if it's a different error
@ -4052,7 +4052,7 @@ class SparseCollective(MultiProcessTestCase):
loss.backward()
self.assertTrue(ddp_model.module.embedding.weight.grad.indices, indices)
except RuntimeError as e:
if "allreduce_sparse is only available in the NCCL experimental branch." in str(e):
if "NCCL does not support all_reduce with sparse tensors" in str(e):
pass
else:
# Rethrow the exception if it's a different error

View File

@ -972,6 +972,7 @@ class TestExport(TestCase):
self._test_export_same_as_eager(kw_func, args, kwargs)
@testing.expectedFailureSerDer # we don't save placeholder metadata
@testing.expectedFailureSerDerPreDispatch
@testing.expectedFailureNonStrict
def test_linear_conv(self):
class MyLinear(torch.nn.Module):
@ -1462,6 +1463,7 @@ class TestExport(TestCase):
self.assertEqual(buffer[2].shape, torch.Size([])) # num_batches_tracked
@testing.expectedFailureNonStrict
@testing.expectedFailureSerDerPreDispatch # tracked via: T181382045
def test_export_dynamo_config(self):
class MyModule(torch.nn.Module):
def __init__(self):
@ -1833,6 +1835,7 @@ def forward(self, arg_0):
)
@testing.expectedFailureNonStrict # non-strict does not add deferred runtime assertions
@testing.expectedFailureSerDerPreDispatch # .item call becomes aten.item in predispatch IR
def test_automatic_constrain_size(self):
class M(torch.nn.Module):
def forward(self, x, y):
@ -1888,6 +1891,7 @@ def forward(self, arg_0):
self.assertTrue(isinstance(node.meta["val"], (Tensor, int)))
@testing.expectedFailureNonStrict
@testing.expectedFailureSerDerPreDispatch # .item() becomes aten.item in predispatch IR
def test_export_with_inline_constraints(self):
class Module(torch.nn.Module):
def forward(self, x):
@ -2249,6 +2253,7 @@ def forward(self, arg_0):
)
@testing.expectedFailureSerDer # We don't preserve metadata on graph module
@testing.expectedFailureSerDerPreDispatch
@testing.expectedFailureNonStrict
def test_retrace_graph_level_meta_preservation(self):
class Foo(torch.nn.Module):
@ -2479,6 +2484,7 @@ def forward(self, arg_0):
):
exported_program.module()(torch.rand(2, 3), torch.rand(2, 3))
@testing.expectedFailureSerDerPreDispatch # linear shouldn't decompose
def test_export_decomps_simple(self):
class M(torch.nn.Module):
def __init__(self):
@ -3052,6 +3058,7 @@ def forward(self, arg_0):
self.assertEqual(ep.module()(*inputs), m(*inputs))
@testing.expectedFailureSerDer # symfloat nyi
@testing.expectedFailureSerDerPreDispatch # symfloat nyi
@testing.expectedFailureRetraceability
def test_sym_sqrt(self):
import math

View File

@ -9,6 +9,7 @@ except ImportError:
import testing
from torch.export import export, load, save
from torch.export._trace import _export
test_classes = {}
@ -22,10 +23,21 @@ def mocked_serder_export(*args, **kwargs):
return loaded_ep
def mocked_serder_export_pre_dispatch(*args, **kwargs):
ep = _export(*args, **kwargs, pre_dispatch=True)
buffer = io.BytesIO()
save(ep, buffer)
buffer.seek(0)
loaded_ep = load(buffer)
return loaded_ep
def make_dynamic_cls(cls):
suffix = "_serdes"
suffix_pre_dispatch = "_serdes_pre_dispatch"
cls_prefix = "SerDesExport"
cls_prefix_pre_dispatch = "SerDesExportPreDispatch"
test_class = testing.make_test_cls_with_mocked_export(
cls,
@ -35,11 +47,21 @@ def make_dynamic_cls(cls):
xfail_prop="_expected_failure_serdes",
)
test_class_pre_dispatch = testing.make_test_cls_with_mocked_export(
cls,
cls_prefix_pre_dispatch,
suffix_pre_dispatch,
mocked_serder_export_pre_dispatch,
xfail_prop="_expected_failure_serdes_pre_dispatch",
)
test_classes[test_class.__name__] = test_class
test_classes[test_class_pre_dispatch.__name__] = test_class_pre_dispatch
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
globals()[test_class.__name__] = test_class
globals()[test_class_pre_dispatch.__name__] = test_class_pre_dispatch
test_class.__module__ = __name__
return test_class
test_class_pre_dispatch.__module__ = __name__
tests = [

View File

@ -58,3 +58,8 @@ def expectedFailureRetraceability(fn):
def expectedFailureSerDer(fn):
fn._expected_failure_serdes = True
return fn
def expectedFailureSerDerPreDispatch(fn):
fn._expected_failure_serdes_pre_dispatch = True
return fn
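For context on how this tag is consumed: make_test_cls_with_mocked_export (used in test_serdes.py above) receives the attribute name through its xfail_prop argument and turns tagged tests into expected failures in the generated SerDesExportPreDispatch class. A minimal, self-contained sketch of that pattern; the names and details are illustrative and do not reproduce the real helper:

```python
import unittest

def make_xfail_aware_test_cls(cls, cls_prefix, suffix, xfail_prop):
    # Copy each test_* method onto a generated class; methods tagged with the
    # given xfail property (as the decorator above does) become expected failures.
    attrs = {}
    for name, fn in list(vars(cls).items()):
        if not name.startswith("test_") or not callable(fn):
            continue
        if getattr(fn, xfail_prop, False):
            fn = unittest.expectedFailure(fn)
        attrs[name + suffix] = fn
    return type(cls_prefix + cls.__name__, (unittest.TestCase,), attrs)
```

A test marked with @testing.expectedFailureSerDerPreDispatch therefore still runs normally in the base class and is only expected to fail in the generated pre-dispatch serdes variant.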

View File

@ -1182,6 +1182,35 @@ class MutationTests(torch._dynamo.test_case.TestCase):
["out_ptr"],
)
@make_mutation_test
def test_reduce_sum():
@triton.jit
def reduce_sum_kernel(a_ptr, c_ptr, stride_am, stride_an):
offs_am = tl.arange(0, 4)
offs_an = tl.arange(0, 4)
a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_an[None, :] * stride_an
)
a = tl.load(a_ptrs)
m = tl.sum(a, axis=1)
tl.store(c_ptr + tl.arange(0, 4), m)
return (
reduce_sum_kernel,
{
"a_ptr": torch.randn(4, 4),
"c_ptr": torch.randn(4),
"stride_am": 4,
"stride_an": 4,
},
# TODO(aakhundov): tt.reduce is now supported, but only
# in the new MLIR-based Triton analysis pass (not in the
# old TTIR string parsing-based one). change the line
# below to ["c_ptr"] when new Triton pin lands and this
# test starts failing.
["a_ptr", "c_ptr"],
)
@make_mutation_test
def test_argmax():
@triton.jit
@ -1204,7 +1233,11 @@ class MutationTests(torch._dynamo.test_case.TestCase):
"stride_am": 4,
"stride_an": 4,
},
# TODO(oulgen): tt.reduce closures are not implemented yet
# TODO(aakhundov): tt.reduce is now supported, but only
# in the new MLIR-based Triton analysis pass (not in the
# old TTIR string parsing-based one). change the line
# below to ["c_ptr"] when new Triton pin lands and this
# test starts failing.
["a_ptr", "c_ptr"],
)

View File

@ -66,6 +66,8 @@ CUSPARSE_SPMM_COMPLEX128_SUPPORTED = (
IS_WINDOWS and torch.version.cuda and version.parse(torch.version.cuda) > version.parse("11.2")
) or (not IS_WINDOWS and not TEST_WITH_ROCM)
HIPSPARSE_SPMM_COMPLEX128_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("6.0")
def all_sparse_layouts(test_name='layout', include_strided=False):
return parametrize(test_name, [
subtest(torch.strided, name='Strided'),

View File

@ -21,7 +21,7 @@ from torch.testing._internal.common_dtype import (
floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and,
all_types_and_complex, floating_and_complex_types_and)
from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED
from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
import operator
if TEST_SCIPY:
@ -2024,7 +2024,9 @@ class TestSparseCSR(TestCase):
@dtypesIfCUDA(*floating_types_and(torch.complex64,
*[torch.bfloat16] if SM80OrLater else [],
*[torch.half] if SM53OrLater else [],
*[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
*[torch.complex128]
if CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
else []))
@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
def test_addmm_sizes_all_sparse_csr(self, device, dtype, m, n, k):

View File

@ -28,13 +28,76 @@ except ImportError:
CPP_H_NAME = "spv.h"
CPP_SRC_NAME = "spv.cpp"
DEFAULT_ENV = {
DEFAULT_ENV: Dict[str, Any] = {
"PRECISION": "highp",
"FLOAT_IMAGE_FORMAT": "rgba16f",
"INT_IMAGE_FORMAT": "rgba32i",
"UINT_IMAGE_FORMAT": "rgba32ui",
}
TYPES_ENV: Dict[str, Any] = {
"IMAGE_FORMAT": {
"float": "rgba32f",
"half": "rgba16f",
"int": "rgba32i",
"uint": "rgba32ui",
"int8": "rgba8i",
"uint8": "rgba8ui",
},
"IMAGE_T": {
3: {
"float": "image3D",
"half": "image3D",
"int": "iimage3D",
"uint": "uimage3D",
},
2: {
"float": "image2D",
"half": "image2D",
"int": "iimage2D",
"uint": "uimage2D",
},
},
"SAMPLER_T": {
3: {
"float": "sampler3D",
"half": "sampler3D",
"int": "isampler3D",
"uint": "usampler3D",
},
2: {
"float": "sampler2D",
"half": "sampler2D",
"int": "isampler2D",
"uint": "usampler2D",
},
},
"VEC4_T": {
"float": "vec4",
"half": "vec4",
"int": "ivec4",
"uint": "uvec4",
"int8": "vec4",
"uint8": "uvec4",
},
"T": {
"float": "float",
"half": "float",
"int": "int",
"uint": "uint",
"int8": "int",
"uint8": "uint8",
},
}
FUNCS_ENV: Dict[str, Any] = {
"GET_POS": {
3: lambda pos: pos,
2: lambda pos: f"{pos}.xy",
}
}
def extract_filename(path: str, keep_ext: bool = True) -> Any:
if keep_ext:
@ -671,7 +734,10 @@ def main(argv: List[str]) -> int:
)
options = parser.parse_args()
DEFAULT_ENV.update(TYPES_ENV)
DEFAULT_ENV.update(FUNCS_ENV)
env = DEFAULT_ENV
for key, value in parse_arg_env(options.env).items():
env[key] = value
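To make the lookup structure above concrete, here is a small self-contained illustration with values copied from TYPES_ENV and FUNCS_ENV; how the template expansion in gen_vulkan_spv.py actually consumes the merged env is not reproduced here:

```python
VEC4_T = {"float": "vec4", "half": "vec4", "int": "ivec4", "uint": "uvec4"}
IMAGE_T = {3: {"half": "image3D", "int": "iimage3D"},
           2: {"half": "image2D", "int": "iimage2D"}}
GET_POS = {3: lambda pos: pos, 2: lambda pos: f"{pos}.xy"}

print(VEC4_T["half"])       # half-precision texels are declared as vec4
print(IMAGE_T[3]["int"])    # a 3D integer image becomes iimage3D
print(GET_POS[2]("pos"))    # 2D shaders index textures with pos.xy
```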

View File

@ -113,6 +113,39 @@ graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
@dataclass(frozen=True)
class VariableTrackerCacheKey:
vt_id: int
# Two different sources can point to the same object. However, Dynamo handles
# global and local sources differently when it comes to guards and possibly
# some other parts as well. So, the cache also relies on the source.
source: Source
class VariableTrackerCache:
def __init__(self):
self.cache = {}
def lookup(self, value, source):
key = VariableTrackerCacheKey(id(value), source)
if key not in self.cache:
return None
return self.cache[key]
def add(self, value, source, vt):
key = VariableTrackerCacheKey(id(value), source)
self.cache[key] = vt
def clone(self):
# Needed for copy and restore graph state
new_cache = VariableTrackerCache()
new_cache.cache.update(self.cache)
return new_cache
def clear(self):
self.cache.clear()
class OutputGraphState(NamedTuple):
input_source_to_var: Dict[Source, VariableTracker]
tracked_fakes: List[TrackedFake]
@ -122,6 +155,7 @@ class OutputGraphState(NamedTuple):
global_state: Optional[Dict[str, bool]]
param_name_to_source: Optional[Dict[str, Source]]
side_effects: SideEffects
variable_tracker_cache: VariableTrackerCache
timestamp: int
non_compliant_ops: Set[torch._ops.OpOverload]
compliant_custom_ops: Set[torch._ops.OpOverload]
@ -320,6 +354,9 @@ class OutputGraph(Checkpointable[OutputGraphState]):
# Stores the full fqn of a param or buffer to the relevant source.
self.param_name_to_source: Optional[Dict[str, Source]] = dict()
self.side_effects = SideEffects()
# Cached variable trackers. This makes symbolic analysis of LOAD_GLOBAL
# and LOAD_ATTR for the same Python objects free.
self.variable_tracker_cache = VariableTrackerCache()
self.code_options = dict(code_options)
self.output_instructions: List[Instruction] = []
# used to track nodes that are added between calls of copy_graphstate
@ -592,6 +629,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
global_state,
dict(self.param_name_to_source),
self.side_effects.clone(),
self.variable_tracker_cache.clone(),
self.timestamp,
set(self.non_compliant_ops),
set(self.compliant_custom_ops),
@ -610,6 +648,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
global_state,
self.param_name_to_source,
self.side_effects,
self.variable_tracker_cache,
self.timestamp,
self.non_compliant_ops,
self.compliant_custom_ops,
@ -1571,6 +1610,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
self.real_value_cache.clear()
self.input_name_to_proxy.clear()
self.side_effects.clear()
self.variable_tracker_cache.clear()
self.register_finalizer_fns.clear()
self.dynamo_flat_name_to_original_fqn.clear()
self.tracing_context.clear()
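A standalone sketch of the caching pattern introduced above: trackers are keyed on (id(value), source), so the same Python object reached through a different source is a cache miss. VariableTracker and Source are stubbed with plain values here; only the keying scheme mirrors the diff:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class CacheKey:
    vt_id: int
    source: str  # Dynamo uses a Source object; a string stands in here

class VariableTrackerCacheSketch:
    def __init__(self):
        self.cache = {}

    def lookup(self, value, source):
        return self.cache.get(CacheKey(id(value), source))

    def add(self, value, source, vt):
        self.cache[CacheKey(id(value), source)] = vt

cache = VariableTrackerCacheSketch()
obj = object()
cache.add(obj, "GlobalSource(foo)", "vt_for_global_foo")
assert cache.lookup(obj, "GlobalSource(foo)") == "vt_for_global_foo"
assert cache.lookup(obj, "LocalSource(foo)") is None  # different source -> miss
```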

View File

@ -267,10 +267,17 @@ class VariableBuilder:
if dup_guard:
self.install_guards(dup_guard)
return side_effect_result
cached_vt = self.tx.output.variable_tracker_cache.lookup(value, self.source)
if cached_vt:
return cached_vt
vt = self._wrap(value)
vt.source = self.source
if self._can_lift_attrs_to_inputs(vt):
vt = self.tx.output.side_effects.track_object_existing(value, vt)
self.tx.output.variable_tracker_cache.add(value, self.source, vt)
return vt
def _can_lift_attrs_to_inputs(self, vt):

View File

@ -248,29 +248,50 @@ def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]:
fn_name = op.get_str_attr("sym_name")
functions[fn_name] = fn_ops
elif child_block_ids:
if name in ("scf.if", "scf.for", "scf.while"):
# for blocked control flow ops: inline the enclosed
# ops into the parent block + rewire the last op in
# each child block (yield) to return the scf result
yield_ops = []
if name in ("scf.if", "scf.for", "scf.while", "tt.reduce"):
# for blocked ops: inline the enclosed ops into
# the parent block + rewire the last op in each
# child block to return the block result
return_ops = []
for block_id in child_block_ids:
# the block args used as operands of the ops in the block
# (and nested blocks inlined in the current block by now)
# are replaced by new fake Intermediates to avoid "this
# operand is not returned by anything other op in the fn"
# error in the downstream analysis
for idx in block_id_to_block_arg_ids[block_id]:
next_fake_intermediate -= 1
replacements[idx] = Intermediate(next_fake_intermediate)
if name.startswith("scf."):
# the scf block args are ignored by the pass. but, as they
# may be used as operands of the ops inside the block
# (and nested blocks inlined in the current block by now),
# they are replaced by new fake Intermediates to avoid "this
# operand is not returned by any other op in the fn" error
# in the downstream analysis
for idx in block_id_to_block_arg_ids[block_id]:
next_fake_intermediate -= 1
replacements[idx] = Intermediate(next_fake_intermediate)
else:
# for tt.reduce, wire the block arguments to the op arguments
num_operands = len(operand_ids)
block_arg_ids = block_id_to_block_arg_ids[block_id]
assert len(block_arg_ids) == 2 * num_operands, (
"tt.reduce is expected to have twice as "
"many block arguments as op arguments: "
f"{operand_ids=}, {block_arg_ids=}."
)
for i, idx in enumerate(block_arg_ids):
# for a tt.reduce op with N arguments, the block
# arguments comprise N reduced values followed by
# N current values corresponding to the N op args
replacements[idx] = Intermediate(
operand_ids[i % num_operands]
)
if block_id in op_stack:
block_ops = op_stack.pop(block_id)
if not block_ops:
continue
last_ret, last_ops = block_ops.popitem()
if all(op.name == "scf.yield" for op in last_ops):
# if last_ops are scf.yield, treat them separately
yield_ops.extend(last_ops)
if all(
op.name in ("scf.yield", "tt.reduce.return")
for op in last_ops
):
# if last_ops are all return ops, treat them separately
return_ops.extend(last_ops)
else:
# otherwise, return last_ops to the block
block_ops[last_ret] = last_ops
@ -279,10 +300,9 @@ def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]:
scf_results = [Intermediate(idx) for idx in result_ids]
for scf_result in scf_results:
for yield_op in yield_ops:
op_stack[parent_block_id][scf_result].append(yield_op)
for return_op in return_ops:
op_stack[parent_block_id][scf_result].append(return_op)
else:
# TODO(oulgen): add support for tt.reduce
raise Exception(
f"Unknown blocked function: {name}. Can't capture the TTIR."
)
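A worked example of the tt.reduce block-argument rewiring above, using made-up SSA ids. For a reduce over N = 2 operands the combine block has 2*N arguments (N accumulator values followed by N current values), and each is wired back to the corresponding op operand:

```python
operand_ids = [10, 11]            # SSA ids of the tt.reduce operands
block_arg_ids = [20, 21, 22, 23]  # combine-region block arguments
num_operands = len(operand_ids)

# In the real pass each value is wrapped in Intermediate(...); plain ids are
# used here to keep the mapping easy to read.
replacements = {
    idx: operand_ids[i % num_operands] for i, idx in enumerate(block_arg_ids)
}
print(replacements)  # {20: 10, 21: 11, 22: 10, 23: 11}
```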

View File

@ -2002,7 +2002,7 @@ class CppCodeCache:
cpp_compile_command(
input=input_path,
output=output_path,
vec_isa=cls.vec_isa,
vec_isa=picked_vec_isa,
**cls.cpp_compile_command_flags,
)
)

File diff suppressed because it is too large.

View File

@ -1,9 +1,9 @@
import torch
from torch.library import Library, impl
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
from typing import Tuple
from torch._refs import _unsqueeze_multiple
from typing import Optional, Tuple
import torch
from torch._refs import _unsqueeze_multiple
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
from torch.library import impl, Library
# Note: decomposed means decomposed quantized tensor, using decomposed so that the
# name is not too long
@ -13,7 +13,7 @@ _DTYPE_TO_QVALUE_BOUNDS = {
torch.uint8: (0, 255),
torch.int8: (-128, 127),
torch.int16: (-(2**15), 2**15 - 1),
torch.int32: (-(2**31), 2**31 - 1)
torch.int32: (-(2**31), 2**31 - 1),
}
# Helper to check the passed in quant min and max are valid for the dtype
@ -60,13 +60,26 @@ def quantize_per_tensor(
"""
if input.dtype == torch.bfloat16:
input = input.to(torch.float32)
assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
inv_scale = 1.0 / scale
return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)
@impl(quantized_decomposed_lib, "quantize_per_tensor", "Meta")
def quantize_per_tensor_meta(
input: torch.Tensor,
scale: float,
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype
) -> torch.Tensor:
if input.dtype == torch.bfloat16:
input = input.to(torch.float32)
assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
return torch.empty_like(input, dtype=dtype)
quantized_decomposed_lib.define(
"quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
"int quant_min, int quant_max, ScalarType dtype) -> Tensor")
@ -90,7 +103,14 @@ def quantize_per_tensor_tensor(
return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
def quantize_per_tensor_tensor_meta(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype
) -> torch.Tensor:
if input.dtype == torch.bfloat16:
input = input.to(torch.float32)
assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
@ -122,7 +142,14 @@ def quantize_per_tensor_tensor2(
return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype)
@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta")
def quantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype):
def quantize_per_tensor_tensor2_meta(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: torch.Tensor,
quant_max: torch.Tensor,
dtype: torch.dtype
) -> torch.Tensor:
return quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype)
# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
@ -131,7 +158,7 @@ def quantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_
# We will revisit this later if we found there are no use cases for it
quantized_decomposed_lib.define(
"dequantize_per_tensor(Tensor input, float scale, int zero_point, "
"int quant_min, int quant_max, ScalarType dtype) -> Tensor")
"int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor(
@ -140,7 +167,9 @@ def dequantize_per_tensor(
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype
dtype: torch.dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
""" Affine dequantization for the Tensor using the same quantization parameters to map
from quantized values to floating point values
@ -163,22 +192,40 @@ def dequantize_per_tensor(
dtype (torch.dtype): dtype for input Tensor (not used in computation,
reserved for pattern matching)
out_dtype (torch.dtype?): optional dtype for output Tensor
Returns:
dequantized float32 Tensor
"""
assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
if out_dtype is None:
out_dtype = torch.float32
if dtype in _DTYPE_TO_QVALUE_BOUNDS:
# TODO: investigate why
# (input - zero_point).to(torch.float32) * scale
# failed the test
return (input.to(torch.float32) - zero_point) * scale
return (input.to(out_dtype) - zero_point) * scale
else:
raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
@impl(quantized_decomposed_lib, "dequantize_per_tensor", "Meta")
def dequantize_per_tensor_meta(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
if out_dtype is None:
out_dtype = torch.float32
return torch.empty_like(input, dtype=out_dtype)
quantized_decomposed_lib.define(
"dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
"int quant_min, int quant_max, ScalarType dtype) -> Tensor")
"int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor_tensor(
@ -187,7 +234,9 @@ def dequantize_per_tensor_tensor(
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype
dtype: torch.dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
""" Affine dequantization for the Tensor using the same quantization parameters to map
from quantized values to floating point values
@ -196,22 +245,33 @@ def dequantize_per_tensor_tensor(
"""
assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype, out_dtype=out_dtype)
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
def dequantize_per_tensor_tensor_meta(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
if out_dtype is None:
out_dtype = torch.float32
assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
if dtype in _DTYPE_TO_QVALUE_BOUNDS:
return torch.empty_like(input, dtype=torch.float32)
return torch.empty_like(input, dtype=out_dtype)
else:
raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
# TODO: remove other variants and keep this one
quantized_decomposed_lib.define(
"dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
"Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor")
"Tensor quant_min, Tensor quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "CompositeExplicitAutograd")
def dequantize_per_tensor_tensor2(
@ -220,7 +280,9 @@ def dequantize_per_tensor_tensor2(
zero_point: torch.Tensor,
quant_min: torch.Tensor,
quant_max: torch.Tensor,
dtype: torch.dtype
dtype: torch.dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
""" Affine dequantization for the Tensor using the same quantization parameters to map
from quantized values to floating point values
@ -229,11 +291,21 @@ def dequantize_per_tensor_tensor2(
"""
assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype)
return dequantize_per_tensor(
input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype, out_dtype=out_dtype)
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta")
def dequantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype):
return dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype)
def dequantize_per_tensor_tensor2_meta(
input,
scale,
zero_point,
quant_min,
quant_max,
dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
return dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype, out_dtype=out_dtype)
quantized_decomposed_lib.define(
"choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
@ -415,7 +487,7 @@ def quantize_per_channel_meta(
# We will revisit this later if we found there are no use cases for it
quantized_decomposed_lib.define(
"dequantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
"int quant_min, int quant_max, ScalarType dtype) -> Tensor")
"int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd")
def dequantize_per_channel(
@ -425,7 +497,9 @@ def dequantize_per_channel(
axis: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype
dtype: torch.dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
""" Affine per channel dequantization for the Tensor using the same quantization
parameters for each channel/axis to map from quantized values to floating point values
@ -450,20 +524,24 @@ def dequantize_per_channel(
dtype (torch.dtype): requested dtype for output Tensor (not used in computation,
reserved for pattern matching)
out_dtype (torch.dtype?): optional dtype for output Tensor
Returns:
dequantized float32 Tensor
"""
assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
if out_dtype is None:
out_dtype = torch.float32
assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
input, permute_axis_list = _permute_to_axis_zero(input, axis)
res = torch.zeros_like(input, dtype=torch.float32)
res = torch.zeros_like(input, dtype=out_dtype)
for i in range(input.size(0)):
# TODO: investigate why
# (input[i] - zero_points[i]).to(torch.float32) * scales[i]
# (input[i] - zero_points[i]).to(out_dtype) * scales[i]
# failed the test
res[i] = (input[i].to(torch.float32) - zero_points[i]) * scales[i]
res[i] = (input[i].to(out_dtype) - zero_points[i]) * scales[i]
out = res.permute(tuple(permute_axis_list))
return out
@ -476,12 +554,16 @@ def dequantize_per_channel_meta(
axis: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype
dtype: torch.dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
if out_dtype is None:
out_dtype = torch.float32
assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
return torch.empty_like(input, dtype=torch.float32)
return torch.empty_like(input, dtype=out_dtype)
quantized_decomposed_lib.define(
"fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "

View File

@ -2918,7 +2918,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
// If the nccl branch is not "exp" then we just error
C10_THROW_ERROR(
Error,
"allreduce_sparse is only available in the NCCL experimental branch.");
"NCCL does not support all_reduce with sparse tensors. Please use dense tensors instead.");
#endif
}