mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Misc visibility changes for compiled autograd (#105298)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105298 Approved by: https://github.com/albanD, https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
cf404a8ce4
commit
e9fd815226
@ -45,7 +45,7 @@ namespace autograd {
|
||||
static constexpr int MAX_DEPTH = 60;
|
||||
|
||||
void set_device(int device);
|
||||
void validate_outputs(
|
||||
TORCH_API void validate_outputs(
|
||||
const edge_list& edges,
|
||||
variable_list& grads,
|
||||
const std::function<std::string(const std::string&)>& format_error);
|
||||
|
@ -230,6 +230,11 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
|
||||
return input_metadata_[index];
|
||||
}
|
||||
|
||||
// Danger: not thread safe, caller must protect with lock
|
||||
InputMetadata& mutable_input_metadata(size_t index) {
|
||||
return input_metadata_[index];
|
||||
}
|
||||
|
||||
/**
|
||||
* Note: Function Streams
|
||||
* A function's stream (for a given device type) is the stream of the first
|
||||
|
@ -26,7 +26,7 @@ struct InputBuffer {
|
||||
// Accumulates the variable at a specified index.
|
||||
// The optional CUDA streams determine which stream the accumulation
|
||||
// is run on and how the addition is synchronized.
|
||||
void add(
|
||||
TORCH_API void add(
|
||||
size_t pos,
|
||||
Variable&& var,
|
||||
const c10::optional<c10::Stream>& opt_producer_stream,
|
||||
|
@ -140,10 +140,21 @@ struct InputMetadata {
|
||||
return was_default_constructed_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool is_nested_tensor() const {
|
||||
return (c10::holds_alternative<at::Tensor>(shape_));
|
||||
}
|
||||
|
||||
c10::SymIntArrayRef shape_as_dim_vector() const {
|
||||
const auto& dim_shape = c10::get<SymIntSmallVec>(shape_);
|
||||
return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size());
|
||||
}
|
||||
|
||||
// Danger: not thread safe, caller must protect with lock
|
||||
SymIntSmallVec& mutable_shape_as_dim_vector() {
|
||||
return c10::get<SymIntSmallVec>(shape_);
|
||||
}
|
||||
|
||||
private:
|
||||
MetadataShape compute_variant_shape(const at::Tensor& input) {
|
||||
if (input.is_nested()) {
|
||||
auto nested_size = at::native::get_nested_sizes(input);
|
||||
@ -152,11 +163,6 @@ struct InputMetadata {
|
||||
return MetadataShape{c10::in_place_type<SymIntSmallVec>, input.sym_sizes()};
|
||||
}
|
||||
|
||||
c10::SymIntArrayRef shape_as_dim_vector() const {
|
||||
const auto& dim_shape = c10::get<SymIntSmallVec>(shape_);
|
||||
return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size());
|
||||
}
|
||||
|
||||
at::Tensor shape_as_tensor() const {
|
||||
return c10::get<at::Tensor>(shape_);
|
||||
}
|
||||
|
@ -48,6 +48,10 @@ class TORCH_API SavedVariable {
|
||||
|
||||
void reset_data();
|
||||
|
||||
bool has_hooks() const {
|
||||
return (bool)hooks_;
|
||||
}
|
||||
|
||||
private:
|
||||
// This field contains either:
|
||||
// 1. the variable to save
|
||||
|
Reference in New Issue
Block a user