Misc visibility changes for compiled autograd (#105298)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105298
Approved by: https://github.com/albanD, https://github.com/soulitzer
This commit is contained in:
Jason Ansel
2023-07-17 14:44:22 -07:00
committed by PyTorch MergeBot
parent cf404a8ce4
commit e9fd815226
5 changed files with 23 additions and 8 deletions

View File

@ -45,7 +45,7 @@ namespace autograd {
static constexpr int MAX_DEPTH = 60;
void set_device(int device);
void validate_outputs(
TORCH_API void validate_outputs(
const edge_list& edges,
variable_list& grads,
const std::function<std::string(const std::string&)>& format_error);

View File

@ -230,6 +230,11 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
return input_metadata_[index];
}
// Danger: not thread safe, caller must protect with lock
InputMetadata& mutable_input_metadata(size_t index) {
return input_metadata_[index];
}
/**
* Note: Function Streams
* A function's stream (for a given device type) is the stream of the first

View File

@ -26,7 +26,7 @@ struct InputBuffer {
// Accumulates the variable at a specified index.
// The optional CUDA streams determine which stream the accumulation
// is run on and how the addition is synchronized.
void add(
TORCH_API void add(
size_t pos,
Variable&& var,
const c10::optional<c10::Stream>& opt_producer_stream,

View File

@ -140,10 +140,21 @@ struct InputMetadata {
return was_default_constructed_;
}
private:
bool is_nested_tensor() const {
return (c10::holds_alternative<at::Tensor>(shape_));
}
c10::SymIntArrayRef shape_as_dim_vector() const {
const auto& dim_shape = c10::get<SymIntSmallVec>(shape_);
return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size());
}
// Danger: not thread safe, caller must protect with lock
SymIntSmallVec& mutable_shape_as_dim_vector() {
return c10::get<SymIntSmallVec>(shape_);
}
private:
MetadataShape compute_variant_shape(const at::Tensor& input) {
if (input.is_nested()) {
auto nested_size = at::native::get_nested_sizes(input);
@ -152,11 +163,6 @@ struct InputMetadata {
return MetadataShape{c10::in_place_type<SymIntSmallVec>, input.sym_sizes()};
}
c10::SymIntArrayRef shape_as_dim_vector() const {
const auto& dim_shape = c10::get<SymIntSmallVec>(shape_);
return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size());
}
at::Tensor shape_as_tensor() const {
return c10::get<at::Tensor>(shape_);
}

View File

@ -48,6 +48,10 @@ class TORCH_API SavedVariable {
void reset_data();
bool has_hooks() const {
return (bool)hooks_;
}
private:
// This field contains either:
// 1. the variable to save