Compare commits

...

1 Commit

SHA1         Message  Date
f41477f01b   wip      2024-05-01 15:59:37 -07:00
8 changed files with 114 additions and 15 deletions

View File

@@ -655,7 +655,7 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
tolerance = args.xla_tolerance if args.trace_on_xla else 1e-4
torch._dynamo.config.repro_tolerance = tolerance
with maybe_profile(args.export_profiler_trace) as p:
with maybe_profile(args.export_profiler_trace, with_stack=args.profile_debug, record_shapes=args.profile_debug) as p:
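# --profile-debug passes with_stack/record_shapes to the profiler so the exported trace also carries python stacks and operator input shapes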
if args.export_aot_inductor:
frozen_model_iter_fn = export_aot_inductor(
model, example_inputs, args.devices[0]
@@ -688,6 +688,7 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
# call mark_step between the 2 calls to make the comparison fair.
maybe_mark_step(args)
print(f"running rep {rep}")
with maybe_mark_profile(p=p, mark="actual"), maybe_enable_compiled_autograd(
args.compiled_autograd
):
@@ -2085,6 +2086,8 @@ class BenchmarkRunner:
self.autocast_arg["dtype"] = amp_dtype
def init_optimizer(self, name, device, params):
self.optimizer = None
return
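# unconditional early return: the optimizer construction below never runs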
if device == "cuda" and self.args.training and name not in CI_SKIP_OPTIMIZER:
if (name in CI_USE_SGD and self.args.ci) or name in BENCHMARK_USE_SGD:
self.optimizer = torch.optim.SGD(params, lr=0.01, foreach=True)
@@ -2706,7 +2709,11 @@ class BenchmarkRunner:
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
t0 = time.perf_counter()
for _ in range(niters):
for i in range(niters):
if mode == "dynamo":
print(f"warm up iteration {i}")
# if i == 2:
# breakpoint()
fn(model, example_inputs)
t1 = time.perf_counter()
latency = t1 - t0
@@ -3210,6 +3217,11 @@ def parse_args(args=None):
action="store_true",
help="exports trace of kineto profiler",
)
parser.add_argument(
"--profile-debug",
action="store_true",
help="args.profile_debug",
)
parser.add_argument(
"--profiler-trace-name",
"--profiler_trace_name",

View File

@@ -28,6 +28,9 @@ from torch.utils._traceback import CapturedTraceback
compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")
cached_fn = None
idx_to_del = []
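# module-level cache shared with the wrapper returned from end_capture below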
def snapshot_verbose_logging_enabled():
return torch._logging._internal.log_state.is_artifact_enabled(
@@ -65,6 +68,8 @@ class AutogradCompilerInstance:
return GetItemSource(LocalSource(name), idx)
def begin_capture(self, inputs: List[torch.Tensor], sizes: List[int]):
global cached_fn
cached_fn = None
counters["compiled_autograd"]["captures"] += 1
self.fx_tracer.root = torch.nn.Module()
self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
@@ -210,17 +215,58 @@ class AutogradCompilerInstance:
self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
)
set_locals_to_steal(graph, ["inputs"])
compiled_autograd_log.info(
"%s", lazy_format_graph_code("Compiled autograd graph", graph)
)
verbose_log.debug(
"%s", lazy_format_graph_code("Compiled autograd graph", graph)
)
# compiled_autograd_log.info(
# "%s", lazy_format_graph_code("Compiled autograd graph", graph)
# )
# verbose_log.debug(
# "%s", lazy_format_graph_code("Compiled autograd graph", graph)
# )
trace_structured(
"compiled_autograd_graph",
payload_fn=lambda: graph.print_readable(print_output=False),
)
return self.compiler_fn(graph)
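# instead of returning self.compiler_fn(graph) directly, return a wrapper that, on
# first call, bakes cpu int64 scalar tensor inputs into the graph as constants,
# compiles once, and caches the callable; every call then swaps dummy cuda tensors
# into those input slots before invoking the cached function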
def wrapper(inputs, sizes, hooks):
global cached_fn
global idx_to_del
global dummy_tensors
if not cached_fn:
idx_to_scalarify = []
for i in range(len(inputs)):
inp = inputs[i]
assert isinstance(inp, torch.Tensor)
if inp.device.type == "cuda" or len(inp.size()) > 0:
# not cpu scalars
continue
# scalars
assert inp.dtype == torch.int64
idx_to_scalarify.append(i)
compiled_autograd_log.info(f"idx_to_scalarify={idx_to_scalarify}")
nodes = [node for node in graph.graph.nodes]
for i in idx_to_scalarify:
node = nodes[i+3] # 3 to offset for inputs, sizes, hooks
for user in list(node.users.keys()):
# bake into graph
user.replace_input_with(node, inputs[i])
# node._remove_from_list()
# compiled_autograd_log.info(
# "%s", lazy_format_graph_code("Compiled autograd graph", graph)
# )
cached_fn = self.compiler_fn(graph)
idx_to_del = idx_to_scalarify
dummy_tensors = [torch.tensor(1, device="cuda") for _ in range(len(idx_to_del))]
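# keep placeholder tensors at the baked-in indices so the cached callable's input list keeps its original length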
for i,idx in enumerate(idx_to_del):
inputs[idx] = dummy_tensors[i]
return cached_fn(inputs, sizes, hooks)
return wrapper
def reorder_accumulate_grad_nodes(self):
"""

View File

@@ -103,7 +103,8 @@ def reduce_to_scalar_loss(out):
"""Reduce the output of a model to get scalar loss"""
if isinstance(out, torch.Tensor):
# Mean does not work on integer tensors
return out.sum() / torch.tensor(out.numel(), device=out.device)
# return out.sum() / torch.tensor(out.numel(), device=out.device)
return out.sum() / out.numel()
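# divide by a plain python int instead of allocating a 0-dim tensor for the divisor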
elif isinstance(out, (list, tuple)):
return sum(reduce_to_scalar_loss(x) for x in out) / len(out)
elif type(out).__name__ in (

View File

@@ -2610,13 +2610,13 @@ def get_first_attr(obj, *attrs):
@contextlib.contextmanager
def maybe_enable_compiled_autograd(should_enable):
def maybe_enable_compiled_autograd(should_enable, fullgraph=False, dynamic=False):
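# fullgraph/dynamic for the compiled autograd graph are now caller-controlled (previously hard-coded to True)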
def compiler_fn(gm):
def inner_compiler(gm_, example_inputs_):
torch._dynamo.utils.counters["compiled_autograd"]["compiles"] += 1
return torch._inductor.compile(gm_, example_inputs_)
return torch.compile(gm, backend=inner_compiler, fullgraph=True, dynamic=True)
return torch.compile(gm, backend=inner_compiler, fullgraph=fullgraph, dynamic=dynamic)
if should_enable:
with torch._dynamo.compiled_autograd.enable(compiler_fn) as ctx:

View File

@@ -773,6 +773,7 @@ class CUDAGraphNode:
]
self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = []
# need to manage these
# tensors which are outputs of previous graphs in the tree
self.cudagraph_managed_idxs: List[int] = [
idx
@@ -780,9 +781,11 @@ class CUDAGraphNode:
if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
]
# relies on wrapped_function.static_input_idxs
self.static_input_idxs: List[int] = list(
set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
)
log.info(f"cudagraph_managed_idxs={len(self.cudagraph_managed_idxs)}, static_input_idxs={len(self.static_input_idxs)}")
self.non_static_input_idx: LevelList[int] = [
i for i in range(len(inputs)) if i not in self.static_input_idxs
@@ -925,8 +928,10 @@ class CUDAGraphNode:
self.graph.replay()
def _copy_inputs_and_remove_from_src(self, dsts, srcs):
log.info(f"_copy_inputs_and_remove_from_src srcs={len(srcs)}, dsts={len(dsts)}")
dst_tensors = []
src_tensors = []
log.info(f"non_static_input_idx={self.non_static_input_idx}")
for idx in self.non_static_input_idx:
if not isinstance(srcs[idx], torch.Tensor):
continue
@@ -934,8 +939,9 @@ class CUDAGraphNode:
dst_tensors.append(index_expanded_dims(dsts[idx], expanded_dims))
src_tensors.append(index_expanded_dims(srcs[idx], expanded_dims))
srcs[idx] = None
log.info(f"dst_tensors={len(dst_tensors)}, src_tensors={len(src_tensors)}")
# Fails on empty lists
if dst_tensors:
if dst_tensors and not torch._dynamo.compiled_autograd.compiled_autograd_enabled:
torch._foreach_copy_(dst_tensors, src_tensors)
def check_static_inputs_are_stable(self, new_inputs):
@@ -1070,6 +1076,7 @@ class CUDAGraphNode:
return output_storages
def run_graph(self):
print("Cudagraph tree being replayed")
assert self.graph is not None
self.graph.replay()
@@ -1785,6 +1792,7 @@ class CUDAGraphTreeManager:
self.running_forwards_with_pending_backwards = False
def run(self, new_inputs: List[Tensor], function_id: FunctionID):
print(f"cudagraph tree manager running function {function_id}")
assert self.graph is not None, "Running CUDAGraph after shutdown"
out = self._run(new_inputs, function_id)

View File

@@ -1580,7 +1580,7 @@ class GraphLowering(torch.fx.Interpreter):
log_module_code(mod.__file__)
log.debug("Output code written to: %s", mod.__file__)
output_code_log.debug("Output code: \n%s", code)
# output_code_log.debug("Output code: \n%s", code)
trace_structured(
"inductor_output_code",
lambda: {"filename": mod.__file__},

View File

@@ -11,6 +11,8 @@
#include <torch/csrc/utils/torch_dispatch_mode.h>
#include <typeindex>
#include <vector>
#include "c10/core/DeviceType.h"
#include <iostream>
// see [Note: Compiled Autograd]
@@ -140,6 +142,9 @@ struct TensorArgs {
}
TensorArg& add(const at::Tensor& tensor) {
if (tensor.defined() && tensor.device() == c10::kCPU) {
std::cout << "lifted cpu tensor, emplacing at idx=" << inputs.size() << std::endl;
}
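// lookup(tensor, /*create=*/true) lifts the tensor as a new graph input if it is not already tracked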
return lookup(tensor, true);
}

View File

@@ -369,8 +369,10 @@ CacheNode* _compiled_autograd_impl(
CacheKey key = node_args.key();
if (is_verbose_logging_enabled &&
cache->lookup(key, /*create=*/false) == nullptr) {
vcout() << "Creating cache entry for " << fn->name()
vcout() << "cache miss for " << fn->name()
<< ", with key of size " << key.key_size << std::endl;
} else {
std::cout << "cache hit for " << fn->name() << std::endl;
}
cache = cache->lookup(key);
}
@@ -517,6 +519,31 @@ CacheNode* _compiled_autograd_impl(
}
}
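// (disabled) alternative: convert lifted cpu scalar tensors to python ints and move cpu non-scalars to cuda before handing inputs to python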
// std::vector<at::Tensor> inputs = compiler_call.tensor_args.inputs;
// PyObject* pyinput = PyList_New(static_cast<Py_ssize_t>(inputs.size()));
// for (const auto i : c10::irange(inputs.size())) {
// auto& t = inputs[i];
// if (t.device() == at::kCPU) {
// if (t.sizes().size() == 0) {
// // cpu scalar
// std::cout << "inputs " << i << " is a scalar, calling .item on dtype=" << t.dtype() << std::endl;
// if (t.item().isIntegral(false)) {
// PyList_SET_ITEM(pyinput, i, PyLong_FromSsize_t(t.item().toLong()));
// } else {
// std::cout << "was not a integral" << t.item().type() << std::endl;
// }
// } else {
// // cpu non scalar
// std::cout << "inputs " << i << " is not a scalar, moving to CUDA" << std::endl;
// PyList_SET_ITEM(pyinput, i, THPVariable_Wrap(t.to(at::kCUDA)));
// }
// } else {
// // cuda
// PyList_SET_ITEM(pyinput, i, THPVariable_Wrap(t));
// }
// }
// *graph_arg_inputs = pyinput;
*graph_arg_inputs = THPVariable_WrapList(compiler_call.tensor_args.inputs);
*graph_arg_sizes = wrap_int_list(compiler_call.dyn_size_inputs);
*graph_arg_hooks = convert_hook_list(compiler_call.hooks);