[ghstack-poisoned]
Author: Nikita Vedeneev
Date: 2025-10-29 13:16:05 +00:00
51 changed files with 262 additions and 156 deletions


@ -1,88 +1,78 @@
#include <ATen/cuda/CUDAGreenContext.h>
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <stdexcept>
#include <vector>
#define HAS_CUDA_GREEN_CONTEXT() 1
#else
#define HAS_CUDA_GREEN_CONTEXT() 0
#endif
namespace at::cuda {
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#if CUDA_HAS_GREEN_CONTEXT
int driver_version;
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
TORCH_CHECK(
driver_version >= 12080, "cuda driver too old to use green context!");
CUcontext pctx = nullptr;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
if (C10_UNLIKELY(!pctx)) {
TORCH_WARN(
"Attempted to create a green context but"
" there was no primary context! Creating a primary context...");
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#if HAS_CUDA_GREEN_CONTEXT()
int driver_version;
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
TORCH_CHECK(
driver_version >= 12080, "cuda driver too old to use green context!");
CUcontext pctx = nullptr;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
if (C10_UNLIKELY(!pctx)) {
TORCH_WARN(
"Attempted to create a green context but"
" there was no primary context! Creating a primary context...");
cudaFree(0);
}
cudaFree(0);
}
CUdevice device;
device_id_ = device_id;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
CUdevice device;
device_id_ = device_id;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
// Get device resources
CUdevResource device_resource;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
// Get device resources
CUdevResource device_resource;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
// Split resources
std::vector<CUdevResource> result(1);
auto result_data = result.data();
unsigned int nb_groups = 1;
CUdevResource remaining;
// Split resources
std::vector<CUdevResource> result(1);
auto result_data = result.data();
unsigned int nb_groups = 1;
CUdevResource remaining;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
result_data,
&nb_groups,
&device_resource,
&remaining,
0, // default flags
num_sms));
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
result_data,
&nb_groups,
&device_resource,
&remaining,
0, // default flags
num_sms));
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
// Generate resource descriptor
CUdevResourceDesc desc;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
&desc, result_data, 1));
// Generate resource descriptor
CUdevResourceDesc desc;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
&desc, result_data, 1));
// Create green context
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
// Create green context
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
// Convert to regular context
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
// Convert to regular context
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
std::unique_ptr<GreenContext> GreenContext::create(
uint32_t num_sms,
std::optional<uint32_t> device_id) {
#if HAS_CUDA_GREEN_CONTEXT()
#if CUDA_HAS_GREEN_CONTEXT
if (!device_id.has_value()) {
device_id = at::cuda::current_device();
}
return std::unique_ptr<GreenContext>(new GreenContext(device_id.value(), num_sms));
return std::make_unique<GreenContext>(device_id.value(), num_sms);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
@ -90,7 +80,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
// Implement move operations
GreenContext::GreenContext(GreenContext&& other) noexcept{
#if HAS_CUDA_GREEN_CONTEXT()
#if CUDA_HAS_GREEN_CONTEXT
device_id_ = std::exchange(other.device_id_, -1);
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
context_ = std::exchange(other.context_, nullptr);
@ -101,7 +91,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
}
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
#if HAS_CUDA_GREEN_CONTEXT()
#if CUDA_HAS_GREEN_CONTEXT
if (this != &other) {
// Clean up current resources
if (green_ctx_) {
@ -130,7 +120,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
}
GreenContext::~GreenContext() noexcept{
#if HAS_CUDA_GREEN_CONTEXT()
#if CUDA_HAS_GREEN_CONTEXT
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
#else
@ -138,9 +128,25 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#endif
}
// Get the underlying CUDA context
CUcontext GreenContext::getContext() const {
#if CUDA_HAS_GREEN_CONTEXT
return context_;
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx GreenContext::getGreenContext() const {
return green_ctx_;
}
#endif
// Make this context current
void GreenContext::setContext() {
#if HAS_CUDA_GREEN_CONTEXT()
#if CUDA_HAS_GREEN_CONTEXT
auto current_stream = c10::cuda::getCurrentCUDAStream();
parent_stream_ = current_stream.stream();
@ -169,7 +175,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
}
void GreenContext::popContext() {
#if HAS_CUDA_GREEN_CONTEXT()
#if CUDA_HAS_GREEN_CONTEXT
// see above note about stream being hardcoded to the default stream
at::cuda::CUDAEvent ev;
ev.record(c10::cuda::getCurrentCUDAStream());


@ -1,38 +1,53 @@
#pragma once
#include <ATen/cuda/CUDAEvent.h>
#include <cuda.h>
// Forward declare green context as opaque ptr
typedef struct CUgreenCtx_st* CUgreenCtx;
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <cuda.h>
#include <memory>
#include <stdexcept>
#include <vector>
#define CUDA_HAS_GREEN_CONTEXT 1
#else
#define CUDA_HAS_GREEN_CONTEXT 0
#endif
namespace at::cuda {
class TORCH_CUDA_CPP_API GreenContext {
public:
// Green context creation
static std::unique_ptr<GreenContext> create(
uint32_t num_sms,
std::optional<uint32_t> device_id);
~GreenContext() noexcept;
GreenContext(uint32_t device_id, uint32_t num_sms);
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
// Delete copy constructor and assignment
GreenContext(const GreenContext&) = delete;
GreenContext& operator=(const GreenContext&) = delete;
// Implement move operations
GreenContext(GreenContext&& other) noexcept;
GreenContext& operator=(GreenContext&& other) noexcept;
~GreenContext() noexcept;
// Get the underlying CUDA context
CUcontext getContext() const;
// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx getGreenContext() const;
#endif
// Make this context current
void setContext();
void popContext();
private:
GreenContext(uint32_t device_id, uint32_t num_sms);
// Implement move operations
GreenContext(GreenContext&& other) noexcept;
GreenContext& operator=(GreenContext&& other) noexcept;
#if CUDA_HAS_GREEN_CONTEXT
int32_t device_id_ = -1;
CUgreenCtx green_ctx_ = nullptr;
CUcontext context_ = nullptr;
cudaStream_t parent_stream_ = nullptr;
#endif
};
} // namespace at::cuda


@ -4222,6 +4222,7 @@ class TestCudaMallocAsync(TestCase):
ss = torch.cuda.memory._snapshot()
trace_plot(ss)
trace_plot(ss, filter_freed=True)
segment_plot(ss)
text = json.dumps(ss)


@ -1791,7 +1791,7 @@ def rewrite_signature(
for i, val in enumerate(sources):
dict_of_source_vals[id(val)] = i
for val in candidates:
for i, val in enumerate(candidates):
if isinstance(val, tuple(common_constant_types)):
matched_elements_positions.append(None)
elif id(val) not in dict_of_source_vals:


@ -317,7 +317,7 @@ class GuardManagerWrapper:
is_diff_guard_node = (
node.get_source() in self.diff_guard_sources or node.fail_count() > 0
)
for _idx, (key_mgr, val_mgr) in sorted(
for idx, (key_mgr, val_mgr) in sorted(
node.get_key_value_managers().items()
):
is_diff_guard_node |= visit(key_mgr) | visit(val_mgr)
@ -440,7 +440,7 @@ class GuardManagerWrapper:
is_subtree_tag_safe = True
# Recurse to get the tag safe roots from subtree.
for _idx, (key_mgr, val_mgr) in sorted(
for idx, (key_mgr, val_mgr) in sorted(
node.get_key_value_managers().items()
):
if key_mgr is not None:
@ -448,7 +448,9 @@ class GuardManagerWrapper:
if val_mgr is not None:
tag_safe_roots.extend(visit(val_mgr))
for key_mgr, val_mgr in node.get_key_value_managers().values():
for idx, (key_mgr, val_mgr) in sorted(
node.get_key_value_managers().items()
):
if key_mgr:
is_subtree_tag_safe &= key_mgr.is_tag_safe()


@ -289,7 +289,9 @@ class OptimizerVariable(UserDefinedObjectVariable):
params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
all_static = True
non_static_grads = []
for p, p_vt in zip(group["params"], params_vt.unpack_var_sequence(tx)):
for p_ind, (p, p_vt) in enumerate(
zip(group["params"], params_vt.unpack_var_sequence(tx))
):
param_source = p_vt.source
self.tensor_to_source[p] = param_source
grad_source = GradSource(
@ -320,12 +322,12 @@ class OptimizerVariable(UserDefinedObjectVariable):
# We have to again iterate over the state dict to collect the
# tensor_to_source dict. This is used for the finalizer.
for idx, value in enumerate(self.value.state.values()):
for idx, (p, value) in enumerate(self.value.state.items()):
p_state_source = DictGetItemSource(
state_source, ConstDictKeySource(state_source, idx)
)
tx.output.guard_on_key_order.add(p_state_source)
for inner_idx, v in enumerate(value.values()):
for inner_idx, (k, v) in enumerate(value.items()):
if (
isinstance(v, torch.Tensor)
and v not in self.grad_to_source


@ -240,7 +240,7 @@ def run_functionalized_fw_and_collect_metadata(
# Inspect the state of the input tensor functional wrapper to detect input mutation info
# If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
for arg, f_arg in zip(flat_args, flat_f_args):
for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)):
# NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in
# strides between the functionalized arg inner tensors and non-functionalized arg inner
# tensors. This is a problem as the inner tensor stride change may not be reflected


@ -2041,7 +2041,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
assert len(meta.attrs) == len(runtime_subclass_keys)
leaves = []
for attr, attr_meta in meta.attrs.items():
for i, (attr, attr_meta) in enumerate(meta.attrs.items()):
elem = getattr(x, attr)
new_elem, elem_leaves = AOTDispatchAutograd.process_runtime_tangent(
elem, attr_meta


@ -98,7 +98,7 @@ def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Modul
module, name, UnwrapTensorSubclass()
)
for child in module.children():
for name, child in module.named_children():
unwrap_tensor_subclass_parameters(child)
return module


@ -1481,7 +1481,9 @@ def functionalize_rng_ops(
)
)
for rng_count, node_pair in enumerate(recomputable_rng_ops_map.values()):
for rng_count, (base_node, node_pair) in enumerate(
recomputable_rng_ops_map.items()
):
# Step 2 - Modify the fwd pass such that
fw_node = node_pair["fwd"]
bw_node = node_pair["bwd"]
@ -2712,7 +2714,9 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
subgraph = getattr(module, hop_node.args[0].target)
if isinstance(subgraph, fx.GraphModule):
new_rng_inputs = []
for placeholder_node in subgraph.graph.find_nodes(op="placeholder"):
for idx, placeholder_node in enumerate(
subgraph.graph.find_nodes(op="placeholder")
):
if rng_string in placeholder_node.name:
# Found a rng state placeholder in the hop graph, lets add
# the corresponding node in the outer graph


@ -116,7 +116,7 @@ def temporarily_restore_interpreter_stack(stack):
pushed.append(s)
yield
finally:
for _ in reversed(pushed):
for s in reversed(pushed):
# TODO: would be nice to assert that the layers are the same, but
# Python object identity is not preserved
pop_dynamic_layer_stack()


@ -907,7 +907,7 @@ def diff_tensor_meta(
try:
if val1 != val2:
pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
except GuardOnDataDependentSymNode:
except GuardOnDataDependentSymNode as _:
pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
continue
return pair_diffs
@ -1197,7 +1197,7 @@ def materialize_callable_in_args(op: HopInstance, args, kwargs):
# call_op preserves ordering of proxies via schema
materialized_args = []
for i, proxy in enumerate(arg_proxies):
for i, (proxy, arg) in enumerate(zip(arg_proxies, schema.arguments)):
if (
isinstance(proxy, torch.fx.Node)
and proxy.op == "get_attr"


@ -316,7 +316,7 @@ def while_loop_dense(
if stack_output:
outs: list[torch.Tensor] = []
for out in outputs:
for i, out in enumerate(outputs):
outs.append(torch.stack(out, dim=0))
return tuple(outs)


@ -2606,7 +2606,7 @@ def custom_op_wrapper(op: str, *args: Any) -> list[c_void_p] | c_void_p | None:
if isinstance(result, (list, tuple)):
# unsafe_alloc_void_ptrs_from_tensors expects result contains tensor only
result = [torch.tensor([]) if r is None else r for r in result]
for r in result:
for i, r in enumerate(result):
assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors"
return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result) # type: ignore[arg-type]


@ -895,7 +895,7 @@ class MetalKernel(SIMDKernel):
else:
dtype_str = self.dtype_to_str(dtype)
code.writeline(f"constant {dtype_str}* {inner},")
for inner in self.args.sizevars.values():
for outer, inner in self.args.sizevars.items():
code.writeline(f"constant long& {inner},")
# Write dynamic values as inputs


@ -218,7 +218,7 @@ class MultiKernel:
# the multi call kernel.
multi_call_args = call_args
multi_call_arg_types = arg_types
for kernel in self.kernels:
for i, kernel in enumerate(self.kernels):
additional_call_args, additional_arg_types = (
kernel.additional_call_args_and_types()
)


@ -717,7 +717,7 @@ class ComboKernel(Kernel):
self, name: str, call_args: list[Any], arg_types: list[Any]
) -> None:
for num, sub_kernel in enumerate(self.sub_kernels):
for tree in sub_kernel.range_trees:
for i, tree in enumerate(sub_kernel.range_trees):
numel_name = f"{tree.prefix}numel_{num}"
if numel_name not in self.dynamic_shape_args:
continue
@ -735,7 +735,7 @@ class ComboKernel(Kernel):
def kernel_benchmark_extra_args(self) -> list[str]:
extra_args = []
for num, sub_kernel in enumerate(self.sub_kernels):
for tree in sub_kernel.range_trees:
for i, tree in enumerate(sub_kernel.range_trees):
numel_name = f"{tree.prefix}numel_{num}"
if numel_name not in self.dynamic_shape_args:
continue
@ -1018,7 +1018,7 @@ class ComboKernel(Kernel):
for num, sub_kernel in enumerate(self.sub_kernels):
meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim
for tree in sub_kernel.range_trees:
for i, tree in enumerate(sub_kernel.range_trees):
# pyrefly: ignore [missing-argument]
if not tree.is_reduction:
numel_name = f"{tree.prefix}numel_{num}"


@ -3600,12 +3600,16 @@ class PythonWrapperCodegen(CodeGen):
self.writeline("if not should_loop:")
if stack_output:
# Handle the case when loop never executes
for i, carried_input in enumerate(outer_carried_inputs):
for i, (carried_input, carried_buf) in enumerate(
zip(outer_carried_inputs, while_loop.carried_inputs)
):
self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
self.writeline(f"{name}[{i}] = {carried_input}.unsqueeze(0).clone()")
self.writeline(ExitSubgraphLine(self))
else:
for i, carried_input in enumerate(outer_carried_inputs):
for i, (carried_input, carried_buf) in enumerate(
zip(outer_carried_inputs, while_loop.carried_inputs)
):
self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
self.writeline(f"{name}[{i}] = {carried_input}.clone()")
self.writeline(ExitSubgraphLine(self))


@ -424,7 +424,10 @@ def _reorder_communication_preserving_peak_memory_internal(
return
# Candidate becomes last use of some bufs
for bufs in group_n_to_bufs_after_swap_dealloc_by_candidate.values():
for (
gn,
bufs,
) in group_n_to_bufs_after_swap_dealloc_by_candidate.items():
for buf in bufs:
buf_to_snode_last_use[buf] = candidate
@ -837,7 +840,7 @@ def _schedule_for_comm(
else:
schedule(snode)
for deps in unmet_deps.values():
for snode, deps in unmet_deps.items():
assert len(deps) == 0, (
f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}"
)
@ -1549,8 +1552,11 @@ Graph: {graph}
node.args = new_args
# Delete `fsdp.copy_(unsharded_param, Y)` nodes
for fsdp_copy_node_idxes in unsharded_param_to_fsdp_copy_node_idxes.values():
for fsdp_copy_node_idx in fsdp_copy_node_idxes:
for (
unsharded_param,
fsdp_copy_node_idxes,
) in unsharded_param_to_fsdp_copy_node_idxes.items():
for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes):
fsdp_copy_node = node_list[fsdp_copy_node_idx]
graph.erase_node(fsdp_copy_node)


@ -46,7 +46,7 @@ def _debug_iterative_memory_recompute(
if iter_cm != new_cm:
log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH"
iterative_recompute_error = True
for gn in gns:
for i, gn in enumerate(gns):
iter_gnm = iter_curr_memory[gn]
new_gnm = est_curr_memory[gn]
if iter_gnm != new_gnm:
@ -65,7 +65,7 @@ def _debug_iterative_memory_recompute(
f"\nCANDIDATE_NEW_ALLOCFREE:{snodes_allocfree[candidate]}"
)
peak_log = ""
for i, (pre, _post) in enumerate(snodes_curr_memory):
for i, (pre, post) in enumerate(snodes_curr_memory):
if est_peak_memory == pre:
n = snodes[i]
peak_log = (


@ -454,7 +454,7 @@ def decompose_map_to_while_loop(gm: torch.fx.GraphModule):
graph_pass.apply(gm)
for _node in gm.graph.find_nodes(
for node in gm.graph.find_nodes(
op="call_function", target=torch.ops.higher_order.map_impl
):
raise AssertionError("map is not lowered to while_loop")
@ -666,7 +666,7 @@ def decompose_scan_to_while_loop(gm: torch.fx.GraphModule):
graph_pass.apply(gm)
for _node in gm.graph.find_nodes(
for node in gm.graph.find_nodes(
op="call_function", target=torch.ops.higher_order.scan
):
raise AssertionError("scan is not lowered to while_loop")
@ -1265,7 +1265,7 @@ def decompose_triton_kernel_wrapper_functional(graph):
graph_pass.apply(graph)
for _ in graph.find_nodes(
for node in graph.find_nodes(
op="call_function",
target=torch.ops.higher_order.triton_kernel_wrapper_functional,
):


@ -8770,7 +8770,9 @@ class WhileLoop(ExternKernel):
seen_buffers: OrderedSet[int] = OrderedSet()
result: list[Union[IRNode, TensorBox, ShapeAsConstantBuffer]] = []
for original_input, unwrapped_buffer in zip(carried_inputs, unwrapped_buffers):
for i, (original_input, unwrapped_buffer) in enumerate(
zip(carried_inputs, unwrapped_buffers)
):
if id(unwrapped_buffer) in seen_buffers:
result.append(ExternKernel.copy_input(original_input))
else:


@ -743,7 +743,7 @@ class _TargetArgsExpr(_TargetExpr):
assert len(node_items) == len(self_items)
m = Match(ctx, self)
for pattern, child_node in zip(self_items, node_items):
for i, pattern, child_node in zip(itertools.count(), self_items, node_items):
if isinstance(pattern, PatternExpr):
child_match = ctx.match(pattern, child_node)
if not is_match(child_match):


@ -2869,7 +2869,7 @@ class Scheduler:
# NB: None means that the dependency is on an input. Don't actually
# generate a dependency because if we do, Inductor will start trying
# to free the unbacked int but that's pointless
for val in V.graph.graph_inputs.values():
for name, val in V.graph.graph_inputs.items():
if isinstance(val, sympy.Expr):
for fs in val.free_symbols:
unbacked_symbol_to_origin_node[fs] = None
@ -3569,7 +3569,9 @@ class Scheduler:
future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
for hint_override in config.multi_kernel_hints:
choice_timings = multi_node.choice_timings(hint_override)
for choice, _ in sorted(choice_timings.items(), key=lambda x: x[1]):
for choice, unfused_time in sorted(
choice_timings.items(), key=lambda x: x[1]
):
if not isinstance(
choice, torch._inductor.select_algorithm.TritonTemplateCaller
):


@ -425,7 +425,7 @@ def apply_var_mapping(
new_ranges, norm_pw_vars + norm_red_vars, strict=True
):
range_vars = []
for _ in range(len(new_range)):
for i in range(len(new_range)):
range_vars.append(flat_vars[count])
count += 1


@ -348,7 +348,7 @@ def _do_bench_using_profiling(
]
) as p:
# Benchmark
for _ in range(n_repeat):
for i in range(n_repeat):
# we clear the L2 cache before each run
cache.zero_()
# record time of `fn`


@ -3118,7 +3118,7 @@ def _validate_symbolic_output_for_caching(
if is_tracing:
# Check for SymNode types in PROXY mode - this should bypass caching
# regardless of whether symbols are known or not
for _ in _iterate_nodes(output):
for node in _iterate_nodes(output):
raise _BypassDispatchCache("Proxy mode with SymNode output")
else:
# Check for unrepresented symbols in tensor expressions


@ -137,7 +137,7 @@ def _get_logger_dict_helper(
def get_prefix(prefix):
return prefix if prefix == "" else prefix + "."
for child in mod.children():
for name, child in mod.named_children():
if isinstance(child, Logger):
target_dict[get_prefix(prefix) + "stats"] = child.stats
break


@ -909,7 +909,8 @@ def create_a_shadows_b(
# is added
prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]
for arg_idx, prev_node_c in enumerate(prev_node_c_list):
for arg_idx, arg in enumerate(prev_node_b):
prev_node_c = prev_node_c_list[arg_idx]
env_c[prev_node_c.name] = _insert_logger_after_node(
prev_node_c,
gm_b,


@ -151,6 +151,6 @@ def bias_correction(
bias.data = updated_bias
# Resets the data contained in the loggers
for submodule in quantized_model.modules():
for name, submodule in quantized_model.named_modules():
if isinstance(submodule, MeanShadowLogger):
submodule.clear()


@ -297,7 +297,7 @@ def _get_numerical_jacobian(
inp_indices = [
i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad
]
for inp, inp_idx in zip(_iter_tensors(target, True), inp_indices):
for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)):
jacobians += [
get_numerical_jacobian_wrt_specific_input(
fn,
@ -549,7 +549,7 @@ def _get_analytical_jacobian_forward_ad(
with fwAD.dual_level():
fw_grads = []
dual_inputs = []
for inp in inputs:
for i, inp in enumerate(inputs):
if is_tensor_like(inp) and inp.requires_grad:
if inp.layout == torch._mkldnn: # type: ignore[attr-defined]
raise ValueError(
@ -1275,7 +1275,7 @@ def _test_undefined_forward_mode(func, outputs, inputs):
tensor_indices.add(i)
dual_inputs.append(inp)
for fw_grad, u in zip(fw_grads, all_u):
for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
fw_grad.copy_(u.view_as(fw_grad))
for idx, inp in enumerate(inputs):


@ -446,7 +446,43 @@ def _format_viz(data, viz_kind, device):
)
def trace_plot(data, device=None, plot_segments=False):
def filter_alloc_free_pairs(data):
for dev_id in range(len(data["device_traces"])):
# set of indexes of trace events for alloc-free pairs
filterSet = set()
# map from addr to index of alloc event
allocMap = {}
# set of addrs from free_requested events
freeRequested = set()
for idx, event in enumerate(data["device_traces"][dev_id]):
if event["action"] == "alloc":
allocMap[event["addr"]] = idx
elif event["action"] == "free_requested":
freeRequested.add(event["addr"])
if allocMap.get(event["addr"]) is not None:
filterSet.add(idx)
filterSet.add(allocMap[event["addr"]])
allocMap.pop(event["addr"])
elif event["action"] == "free_completed":
if event["addr"] in freeRequested:
freeRequested.remove(event["addr"])
filterSet.add(idx)
else:
print(f"free_completed without free_requested: {event}")
# Remove events whose index is in filterSet
if filterSet:
# Create a new list excluding events with indices in filterSet
data["device_traces"][dev_id] = [
event
for idx, event in enumerate(data["device_traces"][dev_id])
if idx not in filterSet
]
return data
def trace_plot(data, device=None, plot_segments=False, filter_freed=False):
"""Generate a visualization over time of the memory usage recorded by the trace as an html file.
Args:
@ -454,10 +490,15 @@ def trace_plot(data, device=None, plot_segments=False):
device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
Defaults to False.
filter_freed (bool, optional): Filter out alloc-free paired events to only plot allocations that are not freed yet.
Defaults to False to plot all trace events.
Returns:
str: HTML of visualization
"""
if filter_freed:
data = filter_alloc_free_pairs(data)
return _format_viz(
data,
"Active Memory Timeline"
@ -698,6 +739,14 @@ if __name__ == "__main__":
"-s", "--segments", action="store_true", help=help
)
help = (
"filter out allocation-free pairs to only visualize the allocations that are not freed yet;"
"useful to reduce the number of events for large traces for debugging OOM"
)
trace_plot_a.add_argument(
"-f", "--filter_freed", action="store_true", help=help
)
args = parser.parse_args()
def _read(name):
@ -734,7 +783,12 @@ if __name__ == "__main__":
data = _read(args.input)
_write(
args.output,
trace_plot(data, device=args.device, plot_segments=args.segments),
trace_plot(
data,
device=args.device,
plot_segments=args.segments,
filter_freed=args.filter_freed,
),
)
elif args.action == "segment_plot":
data = _read(args.input)
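
A minimal usage sketch of the new filter_freed option outside the CLI. The trace_plot signature and the snapshot call come from the hunks above; the module path torch.cuda._memory_viz, the _record_memory_history setup step, and the output file name are illustrative assumptions, since file names are not shown in this diff:

    # Sketch: plot only allocations that were never freed, from a CUDA memory snapshot.
    import torch
    from torch.cuda._memory_viz import trace_plot  # assumed module path

    torch.cuda.memory._record_memory_history()   # start recording allocator trace events
    # ... run the workload under investigation ...
    snapshot = torch.cuda.memory._snapshot()      # same call exercised in the test above

    # filter_freed=True drops paired alloc/free events, keeping only live
    # allocations; this shrinks the timeline for large traces when debugging OOMs.
    html = trace_plot(snapshot, filter_freed=True)
    with open("trace.html", "w") as f:            # output path is arbitrary
        f.write(html)

The same filtering is reachable from the command line through the new -f/--filter_freed flag added to the trace_plot subcommand above.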


@ -41,7 +41,7 @@ class _PseudoZipFile:
pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL)
for data in self.records.keys():
for key, (data, length) in self.records.items():
if isinstance(data, bytes):
f.write(data)
elif isinstance(data, str):


@ -578,7 +578,7 @@ def _load_model_state_dict(
assign = False
if info.broadcast_from_rank0 or info.full_state_dict:
devices = set()
for value in local_state_dict.values():
for key, value in local_state_dict.items():
if torch.is_tensor(value) and value.dim() > 0:
devices.add(value.device)
# In lora state_dict, there could be multiple devices, with meta device inside.


@ -2087,14 +2087,14 @@ class FlatParamHandle:
param.grad.data = view
else:
param.grad = view
for (
for i, (
param_name,
module,
module_name,
prim_param_name,
prim_module,
_,
) in self.flat_param._shared_param_infos:
) in enumerate(self.flat_param._shared_param_infos):
_p_assert(
hasattr(module, param_name),
f"{module_name + '.' + param_name if module_name else param_name} is missing",
@ -2171,8 +2171,11 @@ class FlatParamHandle:
param.data = flat_param[offset : offset + numel_in_shard]
if self.flat_param._shared_params is None:
raise AssertionError("Expected _shared_params to be not None")
for param, (param_name, module, _, prim_param_name, prim_module, _) in zip(
self.flat_param._shared_params, self.flat_param._shared_param_infos
for i, (
param,
(param_name, module, _, prim_param_name, prim_module, _),
) in enumerate(
zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
):
self._setattr_param(module, param_name, param)
prim_param = getattr(prim_module, prim_param_name)
@ -2385,14 +2388,14 @@ class FlatParamHandle:
# TODO: If we want to handle shared parameters, we need to re-generate
# the shared parameter data structures in case sharedness changed.
for (
for i, (
param_name,
module,
_,
prim_param_name,
prim_module,
_,
) in flat_param._shared_param_infos:
) in enumerate(flat_param._shared_param_infos):
if getattr(module, param_name) is not getattr(prim_module, prim_param_name):
raise NotImplementedError(
"Changing shared parameters is not supported yet"


@ -924,7 +924,7 @@ class Pipe(torch.nn.Module):
pass
# This is done by (1) `_sink_params` at each submodule;
for submod in split.children():
for name, submod in split.named_children():
if isinstance(submod, fx.GraphModule):
_sink_params(submod, inputs_to_state, [])
submod.graph.lint()


@ -969,7 +969,7 @@ def distribute_module(
if partition_fn is None:
# if partition_fn not specified, we by default replicate
# all module params/buffers
for submod in module.modules():
for name, submod in module.named_modules():
replicate_module_params_buffers(submod, device_mesh)
else:
# apply partition_fun to submodules


@ -169,7 +169,9 @@ def gen_einsum_strategies(
# linearity strategy
if linearity:
linearity_placement_list: list[Placement] = [Partial()] * (len(input_dims) + 1)
linearity_placement_list: list[Placement] = [Partial()]
for input_dim in input_dims:
linearity_placement_list.append(Partial())
strategies_over_one_mesh_dim.append(linearity_placement_list)
# generate strategies for entire mesh


@ -1333,7 +1333,7 @@ def refine_dynamic_shapes_from_suggested_fixes(
roots.add(c.root.__name__) # type: ignore[attr-defined]
# check keys are existing dims or new roots
for k in shape_fixes.keys():
for k, c in shape_fixes.items():
assert k in name_to_dim or k in roots
# cache so we don't produce multiple derived dim objects


@ -101,11 +101,11 @@ def broadcast_types(t1, t2):
# We make the types the same length which is the first requirement
# for consistency
if s1 > s2:
for _ in range(s1 - s2):
for i in range(s1 - s2):
new_t2.insert(0, 1)
elif s2 > s1:
for _ in range(s2 - s1):
for i in range(s2 - s1):
new_t1.insert(0, 1)
# we replace occurrences of "1" with each tensor with


@ -1871,7 +1871,7 @@ def _make_user_magic(method, user_type):
setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl)
for method in magic_methods.keys(): # type: ignore[assignment]
for method, func in magic_methods.items(): # type: ignore[assignment]
if method in only_bool_magic_methods:
_make_user_magic(method, SymBool)
continue


@ -3342,7 +3342,7 @@ class DimConstraints:
# alter derivations that depend on old root, to unify to new root
# e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2
for old_root in introduced_roots.values():
for c in results.values():
for k, c in list(results.items()):
if (
"eq" in c
and isinstance(c["eq"], sympy.Expr)


@ -1066,7 +1066,7 @@ def call_prepare_scriptable_func_impl(obj, memo):
else:
new_obj_dict[name] = sub_module
for v in new_obj_dict.values():
for k, v in new_obj_dict.items():
obj.__dict__[name] = v
return obj


@ -6099,7 +6099,7 @@ def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None):
if other_dim_rank != self_dim_rank:
delta = self_dim_rank - other_dim_rank
for _ in range(delta):
for i in range(delta):
other = symbolic_helper._unsqueeze_helper(
g, other, [symbolic_helper._get_tensor_rank(other)]
)
@ -6126,10 +6126,10 @@ def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None):
)
other = expand_as(g, other, new_shape)
for _ in range(dim):
for i in range(dim):
index = symbolic_helper._unsqueeze_helper(g, index, [0])
for _ in range(self_dim_rank - dim - 1):
for i in range(self_dim_rank - dim - 1):
index = symbolic_helper._unsqueeze_helper(
g, index, [symbolic_helper._get_tensor_rank(index)]
)


@ -78,7 +78,7 @@ class EncodedAttrs:
attr_floats=[],
attr_strs=[],
)
for k, v in attrs.items():
for i, (k, v) in enumerate(attrs.items()):
encoded.attr_keys.append(k)
if isinstance(v, int):
start_pos = len(encoded.attr_ints)


@ -445,9 +445,11 @@ def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
)
# Checking for permutations of weights and biases as `None`
weights = [channels, None, None]
biases = [None, channels, None]
is_training = [True, False, False]
for training in is_training:
for weight, bias, training in zip(weights, biases, is_training, strict=True):
yield SampleInput(
make_arg(input_shape),
args=(


@ -465,7 +465,7 @@ class DdpUnderDistAutogradTest(RpcAgentTestFixture):
)
# Destroy process groups
for trainer_rref in trainer_rrefs:
for idx, trainer_rref in enumerate(trainer_rrefs):
_remote_method_async(Trainer.destroy_pg, trainer_rref).wait()
# Send shutdown signals.


@ -6094,7 +6094,7 @@ class DistributedTest:
dim=1,
).cuda(rank)
for _ in range(100):
for i in range(100):
y = model(input_var[rank].cuda(rank))
y.mean().backward()


@ -1988,7 +1988,7 @@ class DistAutogradTest(CommonDistAutogradTest):
self.assertEqual(self.world_size - 1, len(known_context_ids))
t1 = torch.rand((3, 3), requires_grad=True)
for _ in range(100):
for i in range(100):
dst = self._next_rank()
t1 = rpc.rpc_sync(worker_name(dst), torch.add, args=(t1, t1))


@ -823,7 +823,7 @@ if has_triton():
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
for _ in range(2):
for i in range(2):
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
i = 2


@ -355,7 +355,7 @@
"dp = dp.shuffle()\n",
"dp = dp.batch(2)\n",
"print(\"Iterate over DataFrame batches\")\n",
"for v in dp:\n",
"for i,v in enumerate(dp):\n",
" print(v)\n",
"\n",
"# this is similar to batching of regular DataPipe\n",