mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 13:34:57 +08:00
Update
[ghstack-poisoned]
@@ -1,88 +1,78 @@
#include <ATen/cuda/CUDAGreenContext.h>

#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <stdexcept>
#include <vector>
#define HAS_CUDA_GREEN_CONTEXT() 1
#else
#define HAS_CUDA_GREEN_CONTEXT() 0
#endif

namespace at::cuda {

GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#if CUDA_HAS_GREEN_CONTEXT
int driver_version;
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
TORCH_CHECK(
driver_version >= 12080, "cuda driver too old to use green context!");
CUcontext pctx = nullptr;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
if (C10_UNLIKELY(!pctx)) {
TORCH_WARN(
"Attempted to create a green context but"
" there was no primary context! Creating a primary context...");

GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#if HAS_CUDA_GREEN_CONTEXT()
int driver_version;
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
TORCH_CHECK(
driver_version >= 12080, "cuda driver too old to use green context!");
CUcontext pctx = nullptr;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
if (C10_UNLIKELY(!pctx)) {
TORCH_WARN(
"Attempted to create a green context but"
" there was no primary context! Creating a primary context...");
cudaFree(0);
}

cudaFree(0);
}
CUdevice device;
device_id_ = device_id;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));

CUdevice device;
device_id_ = device_id;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
// Get device resources
CUdevResource device_resource;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));

// Get device resources
CUdevResource device_resource;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
// Split resources
std::vector<CUdevResource> result(1);
auto result_data = result.data();
unsigned int nb_groups = 1;
CUdevResource remaining;

// Split resources
std::vector<CUdevResource> result(1);
auto result_data = result.data();
unsigned int nb_groups = 1;
CUdevResource remaining;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
result_data,
&nb_groups,
&device_resource,
&remaining,
0, // default flags
num_sms));

C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
result_data,
&nb_groups,
&device_resource,
&remaining,
0, // default flags
num_sms));
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");

TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
// Generate resource descriptor
CUdevResourceDesc desc;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
&desc, result_data, 1));

// Generate resource descriptor
CUdevResourceDesc desc;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
&desc, result_data, 1));
// Create green context
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));

// Create green context
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));

// Convert to regular context
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
// Convert to regular context
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}

std::unique_ptr<GreenContext> GreenContext::create(
uint32_t num_sms,
std::optional<uint32_t> device_id) {
#if HAS_CUDA_GREEN_CONTEXT()
#if CUDA_HAS_GREEN_CONTEXT
if (!device_id.has_value()) {
device_id = at::cuda::current_device();
}
return std::unique_ptr<GreenContext>(new GreenContext(device_id.value(), num_sms));
return std::make_unique<GreenContext>(device_id.value(), num_sms);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif

@@ -90,7 +80,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {

// Implement move operations
GreenContext::GreenContext(GreenContext&& other) noexcept{
#if HAS_CUDA_GREEN_CONTEXT()
#if CUDA_HAS_GREEN_CONTEXT
device_id_ = std::exchange(other.device_id_, -1);
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
context_ = std::exchange(other.context_, nullptr);

@@ -101,7 +91,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
}

GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
#if HAS_CUDA_GREEN_CONTEXT()
#if CUDA_HAS_GREEN_CONTEXT
if (this != &other) {
// Clean up current resources
if (green_ctx_) {

@@ -130,7 +120,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
}

GreenContext::~GreenContext() noexcept{
#if HAS_CUDA_GREEN_CONTEXT()
#if CUDA_HAS_GREEN_CONTEXT
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
#else

@@ -138,9 +128,25 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#endif
}

// Get the underlying CUDA context
CUcontext GreenContext::getContext() const {
#if CUDA_HAS_GREEN_CONTEXT
return context_;
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}

// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx GreenContext::getGreenContext() const {
return green_ctx_;
}
#endif

// Make this context current
void GreenContext::setContext() {
#if HAS_CUDA_GREEN_CONTEXT()
#if CUDA_HAS_GREEN_CONTEXT
auto current_stream = c10::cuda::getCurrentCUDAStream();
parent_stream_ = current_stream.stream();

@@ -169,7 +175,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
}

void GreenContext::popContext() {
#if HAS_CUDA_GREEN_CONTEXT()
#if CUDA_HAS_GREEN_CONTEXT
// see above note about stream being hardcoded to the default stream
at::cuda::CUDAEvent ev;
ev.record(c10::cuda::getCurrentCUDAStream());
@@ -1,38 +1,53 @@
#pragma once
#include <ATen/cuda/CUDAEvent.h>
#include <cuda.h>

// Forward declare green context as opaque ptr
typedef struct CUgreenCtx_st* CUgreenCtx;
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <cuda.h>
#include <memory>
#include <stdexcept>
#include <vector>
#define CUDA_HAS_GREEN_CONTEXT 1
#else
#define CUDA_HAS_GREEN_CONTEXT 0
#endif

namespace at::cuda {

class TORCH_CUDA_CPP_API GreenContext {
public:
// Green context creation
static std::unique_ptr<GreenContext> create(
uint32_t num_sms,
std::optional<uint32_t> device_id);
~GreenContext() noexcept;
GreenContext(uint32_t device_id, uint32_t num_sms);

static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);

// Delete copy constructor and assignment
GreenContext(const GreenContext&) = delete;
GreenContext& operator=(const GreenContext&) = delete;

// Implement move operations
GreenContext(GreenContext&& other) noexcept;
GreenContext& operator=(GreenContext&& other) noexcept;
~GreenContext() noexcept;

// Get the underlying CUDA context
CUcontext getContext() const;

// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx getGreenContext() const;
#endif

// Make this context current
void setContext();

void popContext();

private:
GreenContext(uint32_t device_id, uint32_t num_sms);
// Implement move operations
GreenContext(GreenContext&& other) noexcept;
GreenContext& operator=(GreenContext&& other) noexcept;

#if CUDA_HAS_GREEN_CONTEXT
int32_t device_id_ = -1;
CUgreenCtx green_ctx_ = nullptr;
CUcontext context_ = nullptr;
cudaStream_t parent_stream_ = nullptr;
#endif
};
} // namespace at::cuda
@@ -4222,6 +4222,7 @@ class TestCudaMallocAsync(TestCase):
ss = torch.cuda.memory._snapshot()

trace_plot(ss)
trace_plot(ss, filter_freed=True)
segment_plot(ss)
text = json.dumps(ss)
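The test above exercises the new filter_freed option end to end. A minimal sketch of the same flow outside the test suite (assuming a CUDA build with this patch applied; _record_memory_history and _snapshot are private allocator APIs):

```python
# Minimal sketch: record allocator events, snapshot them, and render the trace
# to HTML. filter_freed=True (added by this patch) drops alloc/free pairs so
# only still-live allocations are plotted.
import torch
from torch.cuda._memory_viz import trace_plot

torch.cuda.memory._record_memory_history(max_entries=100000)

x = torch.randn(1024, 1024, device="cuda")  # alloc
y = x * 2                                   # alloc
del x                                       # freed -> filtered out below
torch.cuda.synchronize()

snapshot = torch.cuda.memory._snapshot()
html_all = trace_plot(snapshot)                      # every trace event
html_live = trace_plot(snapshot, filter_freed=True)  # only unfreed allocations
with open("memory_trace.html", "w") as f:
    f.write(html_live)
```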
@@ -1791,7 +1791,7 @@ def rewrite_signature(
for i, val in enumerate(sources):
dict_of_source_vals[id(val)] = i

for val in candidates:
for i, val in enumerate(candidates):
if isinstance(val, tuple(common_constant_types)):
matched_elements_positions.append(None)
elif id(val) not in dict_of_source_vals:
@@ -317,7 +317,7 @@ class GuardManagerWrapper:
is_diff_guard_node = (
node.get_source() in self.diff_guard_sources or node.fail_count() > 0
)
for _idx, (key_mgr, val_mgr) in sorted(
for idx, (key_mgr, val_mgr) in sorted(
node.get_key_value_managers().items()
):
is_diff_guard_node |= visit(key_mgr) | visit(val_mgr)

@@ -440,7 +440,7 @@ class GuardManagerWrapper:
is_subtree_tag_safe = True

# Recurse to get the tag safe roots from subtree.
for _idx, (key_mgr, val_mgr) in sorted(
for idx, (key_mgr, val_mgr) in sorted(
node.get_key_value_managers().items()
):
if key_mgr is not None:

@@ -448,7 +448,9 @@ class GuardManagerWrapper:
if val_mgr is not None:
tag_safe_roots.extend(visit(val_mgr))

for key_mgr, val_mgr in node.get_key_value_managers().values():
for idx, (key_mgr, val_mgr) in sorted(
node.get_key_value_managers().items()
):
if key_mgr:
is_subtree_tag_safe &= key_mgr.is_tag_safe()
@@ -289,7 +289,9 @@ class OptimizerVariable(UserDefinedObjectVariable):
params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
all_static = True
non_static_grads = []
for p, p_vt in zip(group["params"], params_vt.unpack_var_sequence(tx)):
for p_ind, (p, p_vt) in enumerate(
zip(group["params"], params_vt.unpack_var_sequence(tx))
):
param_source = p_vt.source
self.tensor_to_source[p] = param_source
grad_source = GradSource(

@@ -320,12 +322,12 @@ class OptimizerVariable(UserDefinedObjectVariable):

# We have to again iterate over the state dict to collect the
# tensor_to_source dict. This is used for the finalizer.
for idx, value in enumerate(self.value.state.values()):
for idx, (p, value) in enumerate(self.value.state.items()):
p_state_source = DictGetItemSource(
state_source, ConstDictKeySource(state_source, idx)
)
tx.output.guard_on_key_order.add(p_state_source)
for inner_idx, v in enumerate(value.values()):
for inner_idx, (k, v) in enumerate(value.items()):
if (
isinstance(v, torch.Tensor)
and v not in self.grad_to_source
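The pattern in this and the surrounding hunks recurs throughout the change: loops that previously discarded dict keys or loop positions now bind them explicitly via enumerate() / .items(), so downstream code (guard sources, finalizers, and the like) can refer to the key or index. A tiny standalone illustration (hypothetical dict, not PyTorch code):

```python
# Illustrative only: the recurring refactor shape in this diff.
state = {"w": {"step": 1}, "b": {"step": 2}}

# before: values only, key and position thrown away
for value in state.values():
    pass

# after: index and key are available alongside the value
for idx, (param, value) in enumerate(state.items()):
    print(idx, param, value["step"])
```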
@@ -240,7 +240,7 @@ def run_functionalized_fw_and_collect_metadata(

# Inspect the state of the input tensor functional wrapper to detect input mutation info
# If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
for arg, f_arg in zip(flat_args, flat_f_args):
for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)):
# NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in
# strides between the functionalized arg inner tensors and non-functionalized arg inner
# tensors. This is a problem as the inner tensor stride change may not be reflected

@@ -2041,7 +2041,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa

assert len(meta.attrs) == len(runtime_subclass_keys)
leaves = []
for attr, attr_meta in meta.attrs.items():
for i, (attr, attr_meta) in enumerate(meta.attrs.items()):
elem = getattr(x, attr)
new_elem, elem_leaves = AOTDispatchAutograd.process_runtime_tangent(
elem, attr_meta

@@ -98,7 +98,7 @@ def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Modul
module, name, UnwrapTensorSubclass()
)

for child in module.children():
for name, child in module.named_children():
unwrap_tensor_subclass_parameters(child)

return module

@@ -1481,7 +1481,9 @@ def functionalize_rng_ops(
)
)

for rng_count, node_pair in enumerate(recomputable_rng_ops_map.values()):
for rng_count, (base_node, node_pair) in enumerate(
recomputable_rng_ops_map.items()
):
# Step 2 - Modify the fwd pass such that
fw_node = node_pair["fwd"]
bw_node = node_pair["bwd"]

@@ -2712,7 +2714,9 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
subgraph = getattr(module, hop_node.args[0].target)
if isinstance(subgraph, fx.GraphModule):
new_rng_inputs = []
for placeholder_node in subgraph.graph.find_nodes(op="placeholder"):
for idx, placeholder_node in enumerate(
subgraph.graph.find_nodes(op="placeholder")
):
if rng_string in placeholder_node.name:
# Found a rng state placeholder in the hop graph, lets add
# the corresponding node in the outer graph

@@ -116,7 +116,7 @@ def temporarily_restore_interpreter_stack(stack):
pushed.append(s)
yield
finally:
for _ in reversed(pushed):
for s in reversed(pushed):
# TODO: would be nice to assert that the layers are the same, but
# Python object identity is not preserved
pop_dynamic_layer_stack()

@@ -907,7 +907,7 @@ def diff_tensor_meta(
try:
if val1 != val2:
pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
except GuardOnDataDependentSymNode:
except GuardOnDataDependentSymNode as _:
pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
continue
return pair_diffs

@@ -1197,7 +1197,7 @@ def materialize_callable_in_args(op: HopInstance, args, kwargs):

# call_op preserves ordering of proxies via schema
materialized_args = []
for i, proxy in enumerate(arg_proxies):
for i, (proxy, arg) in enumerate(zip(arg_proxies, schema.arguments)):
if (
isinstance(proxy, torch.fx.Node)
and proxy.op == "get_attr"

@@ -316,7 +316,7 @@ def while_loop_dense(

if stack_output:
outs: list[torch.Tensor] = []
for out in outputs:
for i, out in enumerate(outputs):
outs.append(torch.stack(out, dim=0))
return tuple(outs)

@@ -2606,7 +2606,7 @@ def custom_op_wrapper(op: str, *args: Any) -> list[c_void_p] | c_void_p | None:
if isinstance(result, (list, tuple)):
# unsafe_alloc_void_ptrs_from_tensors expects result contains tensor only
result = [torch.tensor([]) if r is None else r for r in result]
for r in result:
for i, r in enumerate(result):
assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors"
return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result)  # type: ignore[arg-type]

@@ -895,7 +895,7 @@ class MetalKernel(SIMDKernel):
else:
dtype_str = self.dtype_to_str(dtype)
code.writeline(f"constant {dtype_str}* {inner},")
for inner in self.args.sizevars.values():
for outer, inner in self.args.sizevars.items():
code.writeline(f"constant long& {inner},")

# Write dynamic values as inputs

@@ -218,7 +218,7 @@ class MultiKernel:
# the multi call kernel.
multi_call_args = call_args
multi_call_arg_types = arg_types
for kernel in self.kernels:
for i, kernel in enumerate(self.kernels):
additional_call_args, additional_arg_types = (
kernel.additional_call_args_and_types()
)

@@ -717,7 +717,7 @@ class ComboKernel(Kernel):
self, name: str, call_args: list[Any], arg_types: list[Any]
) -> None:
for num, sub_kernel in enumerate(self.sub_kernels):
for tree in sub_kernel.range_trees:
for i, tree in enumerate(sub_kernel.range_trees):
numel_name = f"{tree.prefix}numel_{num}"
if numel_name not in self.dynamic_shape_args:
continue

@@ -735,7 +735,7 @@ class ComboKernel(Kernel):
def kernel_benchmark_extra_args(self) -> list[str]:
extra_args = []
for num, sub_kernel in enumerate(self.sub_kernels):
for tree in sub_kernel.range_trees:
for i, tree in enumerate(sub_kernel.range_trees):
numel_name = f"{tree.prefix}numel_{num}"
if numel_name not in self.dynamic_shape_args:
continue

@@ -1018,7 +1018,7 @@ class ComboKernel(Kernel):

for num, sub_kernel in enumerate(self.sub_kernels):
meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim
for tree in sub_kernel.range_trees:
for i, tree in enumerate(sub_kernel.range_trees):
# pyrefly: ignore [missing-argument]
if not tree.is_reduction:
numel_name = f"{tree.prefix}numel_{num}"

@@ -3600,12 +3600,16 @@ class PythonWrapperCodegen(CodeGen):
self.writeline("if not should_loop:")
if stack_output:
# Handle the case when loop never executes
for i, carried_input in enumerate(outer_carried_inputs):
for i, (carried_input, carried_buf) in enumerate(
zip(outer_carried_inputs, while_loop.carried_inputs)
):
self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
self.writeline(f"{name}[{i}] = {carried_input}.unsqueeze(0).clone()")
self.writeline(ExitSubgraphLine(self))
else:
for i, carried_input in enumerate(outer_carried_inputs):
for i, (carried_input, carried_buf) in enumerate(
zip(outer_carried_inputs, while_loop.carried_inputs)
):
self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
self.writeline(f"{name}[{i}] = {carried_input}.clone()")
self.writeline(ExitSubgraphLine(self))

@@ -424,7 +424,10 @@ def _reorder_communication_preserving_peak_memory_internal(
return

# Candidate becomes last use of some bufs
for bufs in group_n_to_bufs_after_swap_dealloc_by_candidate.values():
for (
gn,
bufs,
) in group_n_to_bufs_after_swap_dealloc_by_candidate.items():
for buf in bufs:
buf_to_snode_last_use[buf] = candidate

@@ -837,7 +840,7 @@ def _schedule_for_comm(
else:
schedule(snode)

for deps in unmet_deps.values():
for snode, deps in unmet_deps.items():
assert len(deps) == 0, (
f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}"
)

@@ -1549,8 +1552,11 @@ Graph: {graph}
node.args = new_args

# Delete `fsdp.copy_(unsharded_param, Y)` nodes
for fsdp_copy_node_idxes in unsharded_param_to_fsdp_copy_node_idxes.values():
for fsdp_copy_node_idx in fsdp_copy_node_idxes:
for (
unsharded_param,
fsdp_copy_node_idxes,
) in unsharded_param_to_fsdp_copy_node_idxes.items():
for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes):
fsdp_copy_node = node_list[fsdp_copy_node_idx]
graph.erase_node(fsdp_copy_node)

@@ -46,7 +46,7 @@ def _debug_iterative_memory_recompute(
if iter_cm != new_cm:
log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH"
iterative_recompute_error = True
for gn in gns:
for i, gn in enumerate(gns):
iter_gnm = iter_curr_memory[gn]
new_gnm = est_curr_memory[gn]
if iter_gnm != new_gnm:

@@ -65,7 +65,7 @@ def _debug_iterative_memory_recompute(
f"\nCANDIDATE_NEW_ALLOCFREE:{snodes_allocfree[candidate]}"
)
peak_log = ""
for i, (pre, _post) in enumerate(snodes_curr_memory):
for i, (pre, post) in enumerate(snodes_curr_memory):
if est_peak_memory == pre:
n = snodes[i]
peak_log = (

@@ -454,7 +454,7 @@ def decompose_map_to_while_loop(gm: torch.fx.GraphModule):

graph_pass.apply(gm)

for _node in gm.graph.find_nodes(
for node in gm.graph.find_nodes(
op="call_function", target=torch.ops.higher_order.map_impl
):
raise AssertionError("map is not lowered to while_loop")

@@ -666,7 +666,7 @@ def decompose_scan_to_while_loop(gm: torch.fx.GraphModule):

graph_pass.apply(gm)

for _node in gm.graph.find_nodes(
for node in gm.graph.find_nodes(
op="call_function", target=torch.ops.higher_order.scan
):
raise AssertionError("scan is not lowered to while_loop")

@@ -1265,7 +1265,7 @@ def decompose_triton_kernel_wrapper_functional(graph):

graph_pass.apply(graph)

for _ in graph.find_nodes(
for node in graph.find_nodes(
op="call_function",
target=torch.ops.higher_order.triton_kernel_wrapper_functional,
):

@@ -8770,7 +8770,9 @@ class WhileLoop(ExternKernel):
seen_buffers: OrderedSet[int] = OrderedSet()
result: list[Union[IRNode, TensorBox, ShapeAsConstantBuffer]] = []

for original_input, unwrapped_buffer in zip(carried_inputs, unwrapped_buffers):
for i, (original_input, unwrapped_buffer) in enumerate(
zip(carried_inputs, unwrapped_buffers)
):
if id(unwrapped_buffer) in seen_buffers:
result.append(ExternKernel.copy_input(original_input))
else:

@@ -743,7 +743,7 @@ class _TargetArgsExpr(_TargetExpr):
assert len(node_items) == len(self_items)

m = Match(ctx, self)
for pattern, child_node in zip(self_items, node_items):
for i, pattern, child_node in zip(itertools.count(), self_items, node_items):
if isinstance(pattern, PatternExpr):
child_match = ctx.match(pattern, child_node)
if not is_match(child_match):

@@ -2869,7 +2869,7 @@ class Scheduler:
# NB: None means that the dependency is on an input. Don't actually
# generate a dependency because if we do, Inductor will start trying
# to free the unbacked int but that's pointless
for val in V.graph.graph_inputs.values():
for name, val in V.graph.graph_inputs.items():
if isinstance(val, sympy.Expr):
for fs in val.free_symbols:
unbacked_symbol_to_origin_node[fs] = None

@@ -3569,7 +3569,9 @@ class Scheduler:
future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
for hint_override in config.multi_kernel_hints:
choice_timings = multi_node.choice_timings(hint_override)
for choice, _ in sorted(choice_timings.items(), key=lambda x: x[1]):
for choice, unfused_time in sorted(
choice_timings.items(), key=lambda x: x[1]
):
if not isinstance(
choice, torch._inductor.select_algorithm.TritonTemplateCaller
):

@@ -425,7 +425,7 @@ def apply_var_mapping(
new_ranges, norm_pw_vars + norm_red_vars, strict=True
):
range_vars = []
for _ in range(len(new_range)):
for i in range(len(new_range)):
range_vars.append(flat_vars[count])
count += 1
@@ -348,7 +348,7 @@ def _do_bench_using_profiling(
]
) as p:
# Benchmark
for _ in range(n_repeat):
for i in range(n_repeat):
# we clear the L2 cache before each run
cache.zero_()
# record time of `fn`
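The hunk above touches the profiling benchmark loop. A self-contained sketch of the same benchmarking idea (not the actual _do_bench_using_profiling implementation): time fn with CUDA events and zero a scratch buffer between runs so each run starts with a cold L2 cache.

```python
# Illustrative benchmarking sketch under the same assumptions as the comment
# above: clear a large buffer before each run, bracket fn() with CUDA events.
import torch

def bench(fn, n_repeat=10, cache_mb=256):
    cache = torch.empty(cache_mb * 1024 * 1024, dtype=torch.int8, device="cuda")
    start = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    end = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    for i in range(n_repeat):
        cache.zero_()          # clear the L2 cache before each run
        start[i].record()
        fn()
        end[i].record()
    torch.cuda.synchronize()
    # elapsed_time() returns milliseconds
    return sum(s.elapsed_time(e) for s, e in zip(start, end)) / n_repeat
```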
@@ -3118,7 +3118,7 @@ def _validate_symbolic_output_for_caching(
if is_tracing:
# Check for SymNode types in PROXY mode - this should bypass caching
# regardless of whether symbols are known or not
for _ in _iterate_nodes(output):
for node in _iterate_nodes(output):
raise _BypassDispatchCache("Proxy mode with SymNode output")
else:
# Check for unrepresented symbols in tensor expressions

@@ -137,7 +137,7 @@ def _get_logger_dict_helper(
def get_prefix(prefix):
return prefix if prefix == "" else prefix + "."

for child in mod.children():
for name, child in mod.named_children():
if isinstance(child, Logger):
target_dict[get_prefix(prefix) + "stats"] = child.stats
break

@@ -909,7 +909,8 @@ def create_a_shadows_b(
# is added
prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]

for arg_idx, prev_node_c in enumerate(prev_node_c_list):
for arg_idx, arg in enumerate(prev_node_b):
prev_node_c = prev_node_c_list[arg_idx]
env_c[prev_node_c.name] = _insert_logger_after_node(
prev_node_c,
gm_b,

@@ -151,6 +151,6 @@ def bias_correction(
bias.data = updated_bias

# Resets the data contained in the loggers
for submodule in quantized_model.modules():
for name, submodule in quantized_model.named_modules():
if isinstance(submodule, MeanShadowLogger):
submodule.clear()

@@ -297,7 +297,7 @@ def _get_numerical_jacobian(
inp_indices = [
i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad
]
for inp, inp_idx in zip(_iter_tensors(target, True), inp_indices):
for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)):
jacobians += [
get_numerical_jacobian_wrt_specific_input(
fn,

@@ -549,7 +549,7 @@ def _get_analytical_jacobian_forward_ad(
with fwAD.dual_level():
fw_grads = []
dual_inputs = []
for inp in inputs:
for i, inp in enumerate(inputs):
if is_tensor_like(inp) and inp.requires_grad:
if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
raise ValueError(

@@ -1275,7 +1275,7 @@ def _test_undefined_forward_mode(func, outputs, inputs):
tensor_indices.add(i)
dual_inputs.append(inp)

for fw_grad, u in zip(fw_grads, all_u):
for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
fw_grad.copy_(u.view_as(fw_grad))

for idx, inp in enumerate(inputs):
@@ -446,7 +446,43 @@ def _format_viz(data, viz_kind, device):
)

def trace_plot(data, device=None, plot_segments=False):
def filter_alloc_free_pairs(data):
for dev_id in range(len(data["device_traces"])):
# set of indexes of trace events for alloc-free pairs
filterSet = set()
# map from addr to index of alloc event
allocMap = {}
# set of addrs from free_requested events
freeRequested = set()
for idx, event in enumerate(data["device_traces"][dev_id]):
if event["action"] == "alloc":
allocMap[event["addr"]] = idx
elif event["action"] == "free_requested":
freeRequested.add(event["addr"])
if allocMap.get(event["addr"]) is not None:
filterSet.add(idx)
filterSet.add(allocMap[event["addr"]])
allocMap.pop(event["addr"])
elif event["action"] == "free_completed":
if event["addr"] in freeRequested:
freeRequested.remove(event["addr"])
filterSet.add(idx)
else:
print(f"free_completed without free_requested: {event}")

# Remove events whose index is in filterSet
if filterSet:
# Create a new list excluding events with indices in filterSet
data["device_traces"][dev_id] = [
event
for idx, event in enumerate(data["device_traces"][dev_id])
if idx not in filterSet
]

return data

def trace_plot(data, device=None, plot_segments=False, filter_freed=False):
"""Generate a visualization over time of the memory usage recorded by the trace as an html file.

Args:
@@ -454,10 +490,15 @@ def trace_plot(data, device=None, plot_segments=False):
device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
Defaults to False.
filter_freed (bool, optional): Filter out alloc-free paired events to only plot allocations that are not freed yet.
Defaults to False to plot all trace events.

Returns:
str: HTML of visualization
"""
if filter_freed:
data = filter_alloc_free_pairs(data)

return _format_viz(
data,
"Active Memory Timeline"
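The filtering rule added above pairs each alloc with its free_requested/free_completed events and drops all three, keeping only still-live allocations. A standalone illustration of the same rule on a hypothetical miniature trace (not the patched function itself):

```python
# Standalone illustration of the pairing rule used by filter_alloc_free_pairs:
# an event is dropped once its alloc and the matching free_requested /
# free_completed have all been seen; unfreed allocations survive.
def drop_freed(events):
    drop, alloc_idx, free_requested = set(), {}, set()
    for idx, ev in enumerate(events):
        if ev["action"] == "alloc":
            alloc_idx[ev["addr"]] = idx
        elif ev["action"] == "free_requested":
            free_requested.add(ev["addr"])
            if ev["addr"] in alloc_idx:
                drop.update({idx, alloc_idx.pop(ev["addr"])})
        elif ev["action"] == "free_completed" and ev["addr"] in free_requested:
            free_requested.remove(ev["addr"])
            drop.add(idx)
    return [ev for i, ev in enumerate(events) if i not in drop]

events = [
    {"action": "alloc", "addr": 0x10},
    {"action": "alloc", "addr": 0x20},            # never freed -> kept
    {"action": "free_requested", "addr": 0x10},
    {"action": "free_completed", "addr": 0x10},
]
assert [e["addr"] for e in drop_freed(events)] == [0x20]
```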
@@ -698,6 +739,14 @@ if __name__ == "__main__":
"-s", "--segments", action="store_true", help=help
)

help = (
"filter out allocation-free pairs to only visualize the allocations that are not freed yet;"
"useful to reduce the number of events for large traces for debugging OOM"
)
trace_plot_a.add_argument(
"-f", "--filter_freed", action="store_true", help=help
)

args = parser.parse_args()

def _read(name):

@@ -734,7 +783,12 @@ if __name__ == "__main__":
data = _read(args.input)
_write(
args.output,
trace_plot(data, device=args.device, plot_segments=args.segments),
trace_plot(
data,
device=args.device,
plot_segments=args.segments,
filter_freed=args.filter_freed,
),
)
elif args.action == "segment_plot":
data = _read(args.input)
@@ -41,7 +41,7 @@ class _PseudoZipFile:

pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL)

for data in self.records.keys():
for key, (data, length) in self.records.items():
if isinstance(data, bytes):
f.write(data)
elif isinstance(data, str):

@@ -578,7 +578,7 @@ def _load_model_state_dict(
assign = False
if info.broadcast_from_rank0 or info.full_state_dict:
devices = set()
for value in local_state_dict.values():
for key, value in local_state_dict.items():
if torch.is_tensor(value) and value.dim() > 0:
devices.add(value.device)
# In lora state_dict, there could be multiple devices, with meta device inside.

@@ -2087,14 +2087,14 @@ class FlatParamHandle:
param.grad.data = view
else:
param.grad = view
for (
for i, (
param_name,
module,
module_name,
prim_param_name,
prim_module,
_,
) in self.flat_param._shared_param_infos:
) in enumerate(self.flat_param._shared_param_infos):
_p_assert(
hasattr(module, param_name),
f"{module_name + '.' + param_name if module_name else param_name} is missing",

@@ -2171,8 +2171,11 @@ class FlatParamHandle:
param.data = flat_param[offset : offset + numel_in_shard]
if self.flat_param._shared_params is None:
raise AssertionError("Expected _shared_params to be not None")
for param, (param_name, module, _, prim_param_name, prim_module, _) in zip(
self.flat_param._shared_params, self.flat_param._shared_param_infos
for i, (
param,
(param_name, module, _, prim_param_name, prim_module, _),
) in enumerate(
zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
):
self._setattr_param(module, param_name, param)
prim_param = getattr(prim_module, prim_param_name)

@@ -2385,14 +2388,14 @@ class FlatParamHandle:

# TODO: If we want to handle shared parameters, we need to re-generate
# the shared parameter data structures in case sharedness changed.
for (
for i, (
param_name,
module,
_,
prim_param_name,
prim_module,
_,
) in flat_param._shared_param_infos:
) in enumerate(flat_param._shared_param_infos):
if getattr(module, param_name) is not getattr(prim_module, prim_param_name):
raise NotImplementedError(
"Changing shared parameters is not supported yet"

@@ -924,7 +924,7 @@ class Pipe(torch.nn.Module):
pass

# This is done by (1) `_sink_params` at each submodule;
for submod in split.children():
for name, submod in split.named_children():
if isinstance(submod, fx.GraphModule):
_sink_params(submod, inputs_to_state, [])
submod.graph.lint()

@@ -969,7 +969,7 @@ def distribute_module(
if partition_fn is None:
# if partition_fn not specified, we by default replicate
# all module params/buffers
for submod in module.modules():
for name, submod in module.named_modules():
replicate_module_params_buffers(submod, device_mesh)
else:
# apply partition_fun to submodules

@@ -169,7 +169,9 @@ def gen_einsum_strategies(

# linearity strategy
if linearity:
linearity_placement_list: list[Placement] = [Partial()] * (len(input_dims) + 1)
linearity_placement_list: list[Placement] = [Partial()]
for input_dim in input_dims:
linearity_placement_list.append(Partial())
strategies_over_one_mesh_dim.append(linearity_placement_list)

# generate strategies for entire mesh

@@ -1333,7 +1333,7 @@ def refine_dynamic_shapes_from_suggested_fixes(
roots.add(c.root.__name__)  # type: ignore[attr-defined]

# check keys are existing dims or new roots
for k in shape_fixes.keys():
for k, c in shape_fixes.items():
assert k in name_to_dim or k in roots

# cache so we don't produce multiple derived dim objects

@@ -101,11 +101,11 @@ def broadcast_types(t1, t2):
# We make the types the same length which is the first requirement
# for consistency
if s1 > s2:
for _ in range(s1 - s2):
for i in range(s1 - s2):
new_t2.insert(0, 1)

elif s2 > s1:
for _ in range(s2 - s1):
for i in range(s2 - s1):
new_t1.insert(0, 1)

# we replace occurrences of "1" with each tensor with
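The two loops above left-pad the shorter type with 1s so both sides have the same rank before broadcasting is checked. A quick illustration with hypothetical shapes:

```python
# Hypothetical shapes: the shorter "type" is left-padded with 1s so both have
# the same rank, which is the precondition checked before broadcasting.
t1 = [2, 3, 4]
t2 = [4]
new_t1, new_t2 = list(t1), list(t2)
for i in range(len(t1) - len(t2)):
    new_t2.insert(0, 1)
assert new_t2 == [1, 1, 4]  # now rank-matched with t1
```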
@@ -1871,7 +1871,7 @@ def _make_user_magic(method, user_type):
setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl)

for method in magic_methods.keys():  # type: ignore[assignment]
for method, func in magic_methods.items():  # type: ignore[assignment]
if method in only_bool_magic_methods:
_make_user_magic(method, SymBool)
continue

@@ -3342,7 +3342,7 @@ class DimConstraints:
# alter derivations that depend on old root, to unify to new root
# e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2
for old_root in introduced_roots.values():
for c in results.values():
for k, c in list(results.items()):
if (
"eq" in c
and isinstance(c["eq"], sympy.Expr)

@@ -1066,7 +1066,7 @@ def call_prepare_scriptable_func_impl(obj, memo):
else:
new_obj_dict[name] = sub_module

for v in new_obj_dict.values():
for k, v in new_obj_dict.items():
obj.__dict__[name] = v

return obj

@@ -6099,7 +6099,7 @@ def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None):

if other_dim_rank != self_dim_rank:
delta = self_dim_rank - other_dim_rank
for _ in range(delta):
for i in range(delta):
other = symbolic_helper._unsqueeze_helper(
g, other, [symbolic_helper._get_tensor_rank(other)]
)

@@ -6126,10 +6126,10 @@ def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None):
)
other = expand_as(g, other, new_shape)

for _ in range(dim):
for i in range(dim):
index = symbolic_helper._unsqueeze_helper(g, index, [0])

for _ in range(self_dim_rank - dim - 1):
for i in range(self_dim_rank - dim - 1):
index = symbolic_helper._unsqueeze_helper(
g, index, [symbolic_helper._get_tensor_rank(index)]
)

@@ -78,7 +78,7 @@ class EncodedAttrs:
attr_floats=[],
attr_strs=[],
)
for k, v in attrs.items():
for i, (k, v) in enumerate(attrs.items()):
encoded.attr_keys.append(k)
if isinstance(v, int):
start_pos = len(encoded.attr_ints)

@@ -445,9 +445,11 @@ def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
)

# Checking for permutations of weights and biases as `None`
weights = [channels, None, None]
biases = [None, channels, None]
is_training = [True, False, False]

for training in is_training:
for weight, bias, training in zip(weights, biases, is_training, strict=True):
yield SampleInput(
make_arg(input_shape),
args=(
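The rewritten loop relies on zip(..., strict=True) (Python 3.10+) to fail loudly if the three parallel lists ever fall out of sync rather than silently truncating. A small illustration with hypothetical values:

```python
# Hypothetical parallel lists: strict=True raises ValueError on a length
# mismatch instead of dropping the trailing elements.
weights = ["w", None, None]
biases = [None, "b", None]
is_training = [True, False, False]

for weight, bias, training in zip(weights, biases, is_training, strict=True):
    pass  # consumes matched triples

try:
    list(zip(weights, biases, [True, False], strict=True))
except ValueError as e:
    print("length mismatch detected:", e)
```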
@@ -465,7 +465,7 @@ class DdpUnderDistAutogradTest(RpcAgentTestFixture):
)

# Destroy process groups
for trainer_rref in trainer_rrefs:
for idx, trainer_rref in enumerate(trainer_rrefs):
_remote_method_async(Trainer.destroy_pg, trainer_rref).wait()

# Send shutdown signals.

@@ -6094,7 +6094,7 @@ class DistributedTest:
dim=1,
).cuda(rank)

for _ in range(100):
for i in range(100):
y = model(input_var[rank].cuda(rank))
y.mean().backward()

@@ -1988,7 +1988,7 @@ class DistAutogradTest(CommonDistAutogradTest):
self.assertEqual(self.world_size - 1, len(known_context_ids))

t1 = torch.rand((3, 3), requires_grad=True)
for _ in range(100):
for i in range(100):
dst = self._next_rank()
t1 = rpc.rpc_sync(worker_name(dst), torch.add, args=(t1, t1))

@@ -823,7 +823,7 @@ if has_triton():
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
for _ in range(2):
for i in range(2):
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
i = 2

@@ -355,7 +355,7 @@
"dp = dp.shuffle()\n",
"dp = dp.batch(2)\n",
"print(\"Iterate over DataFrame batches\")\n",
"for v in dp:\n",
"for i,v in enumerate(dp):\n",
" print(v)\n",
"\n",
"# this is similar to batching of regular DataPipe\n",