Summary:

We need to properly fakify torchbind objects, including the ones in graph module attributes, so the registered fake implementation works properly.

- `_fakify_script_objects` in `compile_fx`
- Allow fake torchbind objects in `torchbind_constants`

Remove `node.meta["unbacked_bindings"]` for `aot_compile` in `compile_fx`. Otherwise `ShapeProp` will fail when trying to resolve the `unbacked_bindings` of `with_effect` tokens.

Update `sigrid_transforms_test` to use the latest `torch._inductor.aot_compile` API.

Add a test for `SigridTransformsInstanceTorchBind` in `e2e_test`.

Test Plan:
```
buck run //caffe2/torch/fb/sparsenn:sigrid_test -- -r test_transform_torch_bind
buck run //sigmoid/inference/test:e2e_test_cpu -- -r SigridTransforms
buck2 run mode/dev-nosan sigmoid/inference/ts_migration:pt2i_readiness_main -- --model_id 545017754 --test_suite ads_all --mode test_preproc
```

Differential Revision: D70013257
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149529
Approved by: https://github.com/angelayi
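
For context, the "registered fake implementation" lives on the Python side: classes in this file that define `__obj_flatten__` are paired with a fake class registered via `torch._library.register_fake_class`. Below is a minimal sketch of that pairing for the `_TensorQueue` class defined in this file, following the pattern used in `torch.testing._internal.torchbind_impls` (the pystub module this file names); the details here are illustrative, not the exact test implementation:

```python
import torch

# Sketch: a fake class paired with _TorchScriptTesting::_TensorQueue below.
# During fakification the runtime flattens the real object with
# __obj_flatten__ and rebuilds a fake stand-in via __obj_unflatten__, so
# compiled/traced graphs never touch the real torchbind object.
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
class FakeTensorQueue:
    def __init__(self, queue):
        self.queue = queue

    @classmethod
    def __obj_unflatten__(cls, flattened_ctx):
        # flattened_ctx mirrors the tuple of (name, value) pairs returned
        # by the C++ __obj_flatten__, e.g. (("queue", [t0, t1, ...]),).
        return cls(**dict(flattened_ctx))

    def push(self, x):
        self.queue.append(x)

    def pop(self):
        return self.queue.pop(0)

    def size(self):
        return len(self.queue)
```
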
#include <test/cpp/jit/test_custom_class_registrations.h>

#include <torch/custom_class.h>
#include <torch/script.h>

#include <deque>
#include <iostream>
#include <mutex>
#include <optional>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

using namespace torch::jit;

namespace {
struct DefaultArgs : torch::CustomClassHolder {
  int x;
  DefaultArgs(int64_t start = 3) : x(start) {}
  int64_t increment(int64_t val = 1) {
    x += val;
    return x;
  }
  int64_t decrement(int64_t val = 1) {
    x -= val;
    return x;
  }
  int64_t scale_add(int64_t add, int64_t scale = 1) {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    x = scale * x + add;
    return x;
  }
  int64_t divide(std::optional<int64_t> factor) {
    if (factor) {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      x = x / *factor;
    }
    return x;
  }
};

struct Foo : torch::CustomClassHolder {
  int x, y;
  Foo() : x(0), y(0) {}
  Foo(int x_, int y_) : x(x_), y(y_) {}
  int64_t info() {
    return this->x * this->y;
  }
  int64_t add(int64_t z) {
    return (x + y) * z;
  }
  at::Tensor add_tensor(at::Tensor z) {
    return (x + y) * z;
  }
  void increment(int64_t z) {
    this->x += z;
    this->y += z;
  }
  int64_t combine(c10::intrusive_ptr<Foo> b) {
    return this->info() + b->info();
  }
  bool eq(c10::intrusive_ptr<Foo> other) {
    return this->x == other->x && this->y == other->y;
  }
  std::tuple<std::tuple<std::string, int64_t>, std::tuple<std::string, int64_t>>
  __obj_flatten__() {
    return std::tuple(std::tuple("x", this->x), std::tuple("y", this->y));
  }
};
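
// Note: __obj_flatten__ returns a tuple of (name, value) pairs describing the
// object's state. The Python side uses it to build a fake stand-in for this
// object during tracing/fakification (see the fake-class sketch in the commit
// message above), so it must stay in sync with the registered fake class.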

struct _StaticMethod : torch::CustomClassHolder {
  // NOLINTNEXTLINE(modernize-use-equals-default)
  _StaticMethod() {}
  static int64_t staticMethod(int64_t input) {
    return 2 * input;
  }
};

struct FooGetterSetter : torch::CustomClassHolder {
  FooGetterSetter() : x(0), y(0) {}
  FooGetterSetter(int64_t x_, int64_t y_) : x(x_), y(y_) {}

  int64_t getX() {
    // to make sure this is not just attribute lookup
    return x + 2;
  }
  void setX(int64_t z) {
    // to make sure this is not just attribute lookup
    x = z + 2;
  }

  int64_t getY() {
    // to make sure this is not just attribute lookup
    return y + 4;
  }

 private:
  int64_t x, y;
};

struct FooGetterSetterLambda : torch::CustomClassHolder {
  int64_t x;
  FooGetterSetterLambda() : x(0) {}
  FooGetterSetterLambda(int64_t x_) : x(x_) {}
};

struct FooReadWrite : torch::CustomClassHolder {
  int64_t x;
  const int64_t y;
  FooReadWrite() : x(0), y(0) {}
  FooReadWrite(int64_t x_, int64_t y_) : x(x_), y(y_) {}
};

struct LambdaInit : torch::CustomClassHolder {
  int x, y;
  LambdaInit(int x_, int y_) : x(x_), y(y_) {}
  int64_t diff() {
    return this->x - this->y;
  }
};

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct NoInit : torch::CustomClassHolder {
  int64_t x;
};

struct PickleTester : torch::CustomClassHolder {
  PickleTester(std::vector<int64_t> vals) : vals(std::move(vals)) {}
  std::vector<int64_t> vals;
};

// Thread-safe Tensor Queue
struct TensorQueue : torch::CustomClassHolder {
  explicit TensorQueue(at::Tensor t) : init_tensor_(t) {}

  explicit TensorQueue(c10::Dict<std::string, at::Tensor> dict) {
    init_tensor_ = dict.at(std::string("init_tensor"));
    const std::string key = "queue";
    at::Tensor size_tensor;
    size_tensor = dict.at(std::string(key + "/size")).cpu();
    const auto* size_tensor_acc = size_tensor.const_data_ptr<int64_t>();
    int64_t queue_size = size_tensor_acc[0];

    for (const auto index : c10::irange(queue_size)) {
      queue_.push_back(dict.at(key + "/" + std::to_string(index)));
    }
  }

  std::tuple<
      std::tuple<std::string, at::Tensor>,
      std::tuple<std::string, std::vector<at::Tensor>>>
  serialize() {
    return std::tuple(
        std::tuple("init_tensor", this->init_tensor_.clone()),
        std::tuple("queue", this->clone_queue()));
  }

  static c10::intrusive_ptr<TensorQueue> deserialize(
      std::tuple<
          std::tuple<std::string, at::Tensor>,
          std::tuple<std::string, std::vector<at::Tensor>>> flattened) {
    TORCH_CHECK(std::tuple_size<decltype(flattened)>::value == 2);

    auto init_tensor_tuple = std::get<0>(flattened);
    TORCH_CHECK(std::tuple_size<decltype(init_tensor_tuple)>::value == 2);
    TORCH_CHECK(std::get<0>(init_tensor_tuple) == std::string("init_tensor"));

    c10::intrusive_ptr<TensorQueue> queue =
        c10::make_intrusive<TensorQueue>(std::get<1>(init_tensor_tuple));

    auto queue_tuple = std::get<1>(flattened);
    TORCH_CHECK(std::tuple_size<decltype(queue_tuple)>::value == 2);
    TORCH_CHECK(std::get<0>(queue_tuple) == std::string("queue"));

    for (auto& value : std::get<1>(queue_tuple)) {
      queue->push(value);
    }

    return queue;
  }

  // Push an element to the rear of the queue.
  // A lock is held for thread safety.
  void push(at::Tensor x) {
    std::lock_guard<std::mutex> guard(mutex_);
    queue_.push_back(x);
  }
  // Pop the front element of the queue and return it.
  // If the queue is empty, return init_tensor_.
  // A lock is held for thread safety.
  at::Tensor pop() {
    std::lock_guard<std::mutex> guard(mutex_);
    if (!queue_.empty()) {
      auto val = queue_.front();
      queue_.pop_front();
      return val;
    } else {
      return init_tensor_;
    }
  }
  // Return the front element of the queue, read-only.
  // We might further optimize with a read-write lock.
  at::Tensor top() {
    std::lock_guard<std::mutex> guard(mutex_);
    if (!queue_.empty()) {
      auto val = queue_.front();
      return val;
    } else {
      return init_tensor_;
    }
  }
  int64_t size() {
    std::lock_guard<std::mutex> guard(mutex_);
    return queue_.size();
  }

  bool is_empty() {
    std::lock_guard<std::mutex> guard(mutex_);
    return queue_.empty();
  }

  double float_size() {
    std::lock_guard<std::mutex> guard(mutex_);
    return 1. * queue_.size();
  }

  std::vector<at::Tensor> clone_queue() {
    std::lock_guard<std::mutex> guard(mutex_);
    std::vector<at::Tensor> ret;
    for (const auto& t : queue_) {
      ret.push_back(t.clone());
    }
    return ret;
  }
  std::vector<at::Tensor> get_raw_queue() {
    std::lock_guard<std::mutex> guard(mutex_);
    std::vector<at::Tensor> raw_queue(queue_.begin(), queue_.end());
    return raw_queue;
  }

  std::tuple<std::tuple<std::string, std::vector<at::Tensor>>>
  __obj_flatten__() {
    return std::tuple(std::tuple("queue", this->get_raw_queue()));
  }

 private:
  std::deque<at::Tensor> queue_;
  std::mutex mutex_;
  at::Tensor init_tensor_;
};
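
// Python-side usage sketch (assuming this test library is loaded), using the
// queue_* ops registered further below:
//   tq = torch.classes._TorchScriptTesting._TensorQueue(torch.zeros(1))
//   torch.ops._TorchScriptTesting.queue_push(tq, torch.ones(2))
//   torch.ops._TorchScriptTesting.queue_pop(tq)   # -> tensor([1., 1.])
//   torch.ops._TorchScriptTesting.queue_size(tq)  # -> 0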

struct ConstantTensorContainer : torch::CustomClassHolder {
  explicit ConstantTensorContainer(at::Tensor x) : x_(x) {}

  at::Tensor get() {
    return x_;
  }

  std::string tracing_mode() {
    return "real";
  }

 private:
  at::Tensor x_;
};

at::Tensor take_an_instance(const c10::intrusive_ptr<PickleTester>& instance) {
  return torch::zeros({instance->vals.back(), 4});
}

struct ElementwiseInterpreter : torch::CustomClassHolder {
  using InstructionType = std::tuple<
      std::string /*op*/,
      std::vector<std::string> /*inputs*/,
      std::string /*output*/>;

  // NOLINTNEXTLINE(modernize-use-equals-default)
  ElementwiseInterpreter() {}

  // Load a list of instructions into the interpreter. As specified above,
  // instructions specify the operation (currently "add" and "mul" are
  // supported), the names of the input values, and the name of the single
  // output value from this instruction.
  void setInstructions(std::vector<InstructionType> instructions) {
    instructions_ = std::move(instructions);
  }

  // Add a constant. The interpreter maintains a set of constants across
  // calls. They are keyed by name, and constants can be referenced in
  // Instructions by the name specified.
  void addConstant(const std::string& name, at::Tensor value) {
    constants_.insert_or_assign(name, std::move(value));
  }

  // Set the string names for the positional inputs to the function this
  // interpreter represents. When invoked, the interpreter will assign
  // the positional inputs to the names in the corresponding position in
  // input_names.
  void setInputNames(std::vector<std::string> input_names) {
    input_names_ = std::move(input_names);
  }

  // Specify the output name for the function this interpreter represents. This
  // should match the "output" field of one of the instructions in the
  // instruction list, typically the last instruction.
  void setOutputName(std::string output_name) {
    output_name_ = std::move(output_name);
  }

  // Invoke this interpreter. This takes a list of positional inputs and
  // returns a single output. Currently, inputs and outputs must all be
  // Tensors.
  at::Tensor __call__(std::vector<at::Tensor> inputs) {
    // Environment to hold local variables
    std::unordered_map<std::string, at::Tensor> environment;

    // Load inputs according to the specified names
    if (inputs.size() != input_names_.size()) {
      std::stringstream err;
      err << "Expected " << input_names_.size() << " inputs, but got "
          << inputs.size() << "!";
      throw std::runtime_error(err.str());
    }
    for (size_t i = 0; i < inputs.size(); ++i) {
      environment[input_names_[i]] = inputs[i];
    }

    for (InstructionType& instr : instructions_) {
      // Retrieve all input values for this op
      std::vector<at::Tensor> op_inputs;
      for (const auto& input_name : std::get<1>(instr)) {
        // Operator output values shadow constants.
        // Imagine all constants are defined in statements at the beginning
        // of a function (a la K&R C). Any definition of an output value must
        // necessarily come after constant definition in textual order. Thus,
        // we look up values in the environment first and the constant table
        // second to implement this shadowing behavior.
        if (environment.find(input_name) != environment.end()) {
          op_inputs.push_back(environment.at(input_name));
        } else if (constants_.find(input_name) != constants_.end()) {
          op_inputs.push_back(constants_.at(input_name));
        } else {
          std::stringstream err;
          err << "Instruction referenced unknown value " << input_name << "!";
          throw std::runtime_error(err.str());
        }
      }

      // Run the specified operation
      at::Tensor result;
      const auto& op = std::get<0>(instr);
      if (op == "add") {
        if (op_inputs.size() != 2) {
          throw std::runtime_error("Unexpected number of inputs for add op!");
        }
        result = op_inputs[0] + op_inputs[1];
      } else if (op == "mul") {
        if (op_inputs.size() != 2) {
          throw std::runtime_error("Unexpected number of inputs for mul op!");
        }
        result = op_inputs[0] * op_inputs[1];
      } else {
        std::stringstream err;
        err << "Unknown operator " << op << "!";
        throw std::runtime_error(err.str());
      }

      // Write back result into environment
      const auto& output_name = std::get<2>(instr);
      environment[output_name] = std::move(result);
    }

    if (!output_name_) {
      throw std::runtime_error("Output name not specified!");
    }

    return environment.at(*output_name_);
  }

  // Ser/De infrastructure. See
  // https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html#defining-serialization-deserialization-methods-for-custom-c-classes
  // for more info.

  // This is the type we will use to marshall information on disk during
  // ser/de. It is a simple tuple composed of primitive types and simple
  // collection types like vector, optional, and dict.
  using SerializationType = std::tuple<
      std::vector<std::string> /*input_names_*/,
      std::optional<std::string> /*output_name_*/,
      c10::Dict<std::string, at::Tensor> /*constants_*/,
      std::vector<InstructionType> /*instructions_*/
      >;

  // This function yields the SerializationType instance for `this`.
  SerializationType __getstate__() const {
    return SerializationType{
        input_names_, output_name_, constants_, instructions_};
  }

  // This function will create an instance of `ElementwiseInterpreter` given
  // an instance of `SerializationType`.
  static c10::intrusive_ptr<ElementwiseInterpreter> __setstate__(
      SerializationType state) {
    auto instance = c10::make_intrusive<ElementwiseInterpreter>();
    std::tie(
        instance->input_names_,
        instance->output_name_,
        instance->constants_,
        instance->instructions_) = std::move(state);
    return instance;
  }

  // Class members
  std::vector<std::string> input_names_;
  std::optional<std::string> output_name_;
  c10::Dict<std::string, at::Tensor> constants_;
  std::vector<InstructionType> instructions_;
};
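
// Python-side usage sketch (assuming this test library is loaded), e.g.:
//   interp = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
//   interp.set_instructions([("add", ["a", "b"], "t"),
//                            ("mul", ["t", "t"], "out")])
//   interp.set_input_names(["a", "b"])
//   interp.set_output_name("out")
//   out = interp.__call__([torch.ones(2), 2 * torch.ones(2)])
//   # a + b = 3, then t * t = 9 -> tensor([9., 9.])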

struct ReLUClass : public torch::CustomClassHolder {
  at::Tensor run(const at::Tensor& t) {
    return t.relu();
  }
};

struct FlattenWithTensorOp : public torch::CustomClassHolder {
  explicit FlattenWithTensorOp(at::Tensor t) : t_(t) {}

  at::Tensor get() {
    return t_;
  }

  std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
    return std::tuple(std::tuple("t", this->t_.sin()));
  }

 private:
  at::Tensor t_;
};

struct ContainsTensor : public torch::CustomClassHolder {
  explicit ContainsTensor(at::Tensor t) : t_(t) {}

  at::Tensor get() {
    return t_;
  }

  std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
    return std::tuple(std::tuple("t", this->t_));
  }

  at::Tensor t_;
};

TORCH_LIBRARY(_TorchScriptTesting, m) {
  m.impl_abstract_pystub("torch.testing._internal.torchbind_impls");
  m.class_<ScalarTypeClass>("_ScalarTypeClass")
      .def(torch::init<at::ScalarType>())
      .def_pickle(
          [](const c10::intrusive_ptr<ScalarTypeClass>& self) {
            return std::make_tuple(self->scalar_type_);
          },
          [](std::tuple<at::ScalarType> s) {
            return c10::make_intrusive<ScalarTypeClass>(std::get<0>(s));
          });

  m.class_<ReLUClass>("_ReLUClass")
      .def(torch::init<>())
      .def("run", &ReLUClass::run);

  m.class_<_StaticMethod>("_StaticMethod")
      .def(torch::init<>())
      .def_static("staticMethod", &_StaticMethod::staticMethod);

  m.class_<DefaultArgs>("_DefaultArgs")
      .def(torch::init<int64_t>(), "", {torch::arg("start") = 3})
      .def("increment", &DefaultArgs::increment, "", {torch::arg("val") = 1})
      .def("decrement", &DefaultArgs::decrement, "", {torch::arg("val") = 1})
      .def(
          "scale_add",
          &DefaultArgs::scale_add,
          "",
          {torch::arg("add"), torch::arg("scale") = 1})
      .def(
          "divide",
          &DefaultArgs::divide,
          "",
          {torch::arg("factor") = torch::arg::none()});

  m.class_<Foo>("_Foo")
      .def(torch::init<int64_t, int64_t>())
      // .def(torch::init<>())
      .def("info", &Foo::info)
      .def("increment", &Foo::increment)
      .def("add", &Foo::add)
      .def("add_tensor", &Foo::add_tensor)
      .def("__eq__", &Foo::eq)
      .def("combine", &Foo::combine)
      .def("__obj_flatten__", &Foo::__obj_flatten__)
      .def_pickle(
          [](c10::intrusive_ptr<Foo> self) { // __getstate__
            return std::vector<int64_t>{self->x, self->y};
          },
          [](std::vector<int64_t> state) { // __setstate__
            return c10::make_intrusive<Foo>(state[0], state[1]);
          });

  m.class_<FlattenWithTensorOp>("_FlattenWithTensorOp")
      .def(torch::init<at::Tensor>())
      .def("get", &FlattenWithTensorOp::get)
      .def("__obj_flatten__", &FlattenWithTensorOp::__obj_flatten__);

  m.class_<ConstantTensorContainer>("_ConstantTensorContainer")
      .def(torch::init<at::Tensor>())
      .def("get", &ConstantTensorContainer::get)
      .def("tracing_mode", &ConstantTensorContainer::tracing_mode);

  m.def(
      "takes_foo(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
  m.def(
      "takes_foo_python_meta(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
  m.def(
      "takes_foo_list_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor[]");
  m.def(
      "takes_foo_tuple_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> (Tensor, Tensor)");
  m.def(
      "takes_foo_tensor_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");

  m.class_<FooGetterSetter>("_FooGetterSetter")
      .def(torch::init<int64_t, int64_t>())
      .def_property("x", &FooGetterSetter::getX, &FooGetterSetter::setX)
      .def_property("y", &FooGetterSetter::getY);

  m.class_<FooGetterSetterLambda>("_FooGetterSetterLambda")
      .def(torch::init<int64_t>())
      .def_property(
          "x",
          [](const c10::intrusive_ptr<FooGetterSetterLambda>& self) {
            return self->x;
          },
          [](const c10::intrusive_ptr<FooGetterSetterLambda>& self,
             int64_t val) { self->x = val; });

  m.class_<FooReadWrite>("_FooReadWrite")
      .def(torch::init<int64_t, int64_t>())
      .def_readwrite("x", &FooReadWrite::x)
      .def_readonly("y", &FooReadWrite::y);

  m.class_<LambdaInit>("_LambdaInit")
      .def(torch::init([](int64_t x, int64_t y, bool swap) {
        if (swap) {
          return c10::make_intrusive<LambdaInit>(y, x);
        } else {
          return c10::make_intrusive<LambdaInit>(x, y);
        }
      }))
      .def("diff", &LambdaInit::diff);

  m.class_<NoInit>("_NoInit").def(
      "get_x", [](const c10::intrusive_ptr<NoInit>& self) { return self->x; });

  m.class_<MyStackClass<std::string>>("_StackString")
      .def(torch::init<std::vector<std::string>>())
      .def("push", &MyStackClass<std::string>::push)
      .def("pop", &MyStackClass<std::string>::pop)
      .def("clone", &MyStackClass<std::string>::clone)
      .def("merge", &MyStackClass<std::string>::merge)
      .def_pickle(
          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
            return self->stack_;
          },
          // __setstate__: note that `state` is deliberately ignored and a
          // fixed stack is produced, so callers can observe that
          // deserialization ran.
          [](std::vector<std::string> state) {
            return c10::make_intrusive<MyStackClass<std::string>>(
                std::vector<std::string>{"i", "was", "deserialized"});
          })
      .def("return_a_tuple", &MyStackClass<std::string>::return_a_tuple)
      .def(
          "top",
          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
              -> std::string { return self->stack_.back(); })
      .def(
          "__str__",
          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
            std::stringstream ss;
            ss << "[";
            for (size_t i = 0; i < self->stack_.size(); ++i) {
              ss << self->stack_[i];
              if (i != self->stack_.size() - 1) {
                ss << ", ";
              }
            }
            ss << "]";
            return ss.str();
          });
  // clang-format off
        // The following will fail with a static assert telling you that you
        // have to take an intrusive_ptr<MyStackClass> as the first argument.
        // .def("foo", [](int64_t a) -> int64_t{ return 3;});
  // clang-format on

  m.class_<PickleTester>("_PickleTester")
      .def(torch::init<std::vector<int64_t>>())
      .def_pickle(
          [](c10::intrusive_ptr<PickleTester> self) { // __getstate__
            return std::vector<int64_t>{1, 3, 3, 7};
          },
          [](std::vector<int64_t> state) { // __setstate__
            return c10::make_intrusive<PickleTester>(std::move(state));
          })
      .def(
          "top",
          [](const c10::intrusive_ptr<PickleTester>& self) {
            return self->vals.back();
          })
      .def("pop", [](const c10::intrusive_ptr<PickleTester>& self) {
        auto val = self->vals.back();
        self->vals.pop_back();
        return val;
      });

  m.def(
      "take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y",
      take_an_instance);
  // test that schema inference is ok too
  m.def("take_an_instance_inferred", take_an_instance);

  m.class_<ElementwiseInterpreter>("_ElementwiseInterpreter")
      .def(torch::init<>())
      .def("set_instructions", &ElementwiseInterpreter::setInstructions)
      .def("add_constant", &ElementwiseInterpreter::addConstant)
      .def("set_input_names", &ElementwiseInterpreter::setInputNames)
      .def("set_output_name", &ElementwiseInterpreter::setOutputName)
      .def("__call__", &ElementwiseInterpreter::__call__)
      .def_pickle(
          /* __getstate__ */
          [](const c10::intrusive_ptr<ElementwiseInterpreter>& self) {
            return self->__getstate__();
          },
          /* __setstate__ */
          [](ElementwiseInterpreter::SerializationType state) {
            return ElementwiseInterpreter::__setstate__(std::move(state));
          });

  m.class_<ContainsTensor>("_ContainsTensor")
      .def(torch::init<at::Tensor>())
      .def("get", &ContainsTensor::get)
      .def("__obj_flatten__", &ContainsTensor::__obj_flatten__)
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<ContainsTensor>& self) -> at::Tensor {
            return self->t_;
          },
          // __setstate__
          [](at::Tensor data) -> c10::intrusive_ptr<ContainsTensor> {
            return c10::make_intrusive<ContainsTensor>(std::move(data));
          });

  m.class_<TensorQueue>("_TensorQueue")
      .def(torch::init<at::Tensor>())
      .def("push", &TensorQueue::push)
      .def("pop", &TensorQueue::pop)
      .def("top", &TensorQueue::top)
      .def("is_empty", &TensorQueue::is_empty)
      .def("float_size", &TensorQueue::float_size)
      .def("size", &TensorQueue::size)
      .def("clone_queue", &TensorQueue::clone_queue)
      .def("get_raw_queue", &TensorQueue::get_raw_queue)
      .def("__obj_flatten__", &TensorQueue::__obj_flatten__)
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<TensorQueue>& self)
              -> std::tuple<
                  std::tuple<std::string, at::Tensor>,
                  std::tuple<std::string, std::vector<at::Tensor>>> {
            return self->serialize();
          },
          // __setstate__
          [](std::tuple<
              std::tuple<std::string, at::Tensor>,
              std::tuple<std::string, std::vector<at::Tensor>>> data)
              -> c10::intrusive_ptr<TensorQueue> {
            return TensorQueue::deserialize(data);
          });
}

at::Tensor takes_foo(c10::intrusive_ptr<Foo> foo, at::Tensor x) {
  return foo->add_tensor(x);
}

std::vector<at::Tensor> takes_foo_list_return(
    c10::intrusive_ptr<Foo> foo,
    at::Tensor x) {
  std::vector<at::Tensor> result;
  result.reserve(3);
  auto a = foo->add_tensor(x);
  auto b = foo->add_tensor(a);
  auto c = foo->add_tensor(b);
  result.push_back(a);
  result.push_back(b);
  result.push_back(c);
  return result;
}

std::tuple<at::Tensor, at::Tensor> takes_foo_tuple_return(
    c10::intrusive_ptr<Foo> foo,
    at::Tensor x) {
  auto a = foo->add_tensor(x);
  auto b = foo->add_tensor(a);
  return std::make_tuple(a, b);
}

at::Tensor takes_foo_tensor_return(c10::intrusive_ptr<Foo> foo, at::Tensor x) {
  return at::ones({foo->x, foo->y}, at::device(at::kCPU).dtype(at::kInt));
}

void queue_push(c10::intrusive_ptr<TensorQueue> tq, at::Tensor x) {
  tq->push(x);
}

at::Tensor queue_pop(c10::intrusive_ptr<TensorQueue> tq) {
  return tq->pop();
}

int64_t queue_size(c10::intrusive_ptr<TensorQueue> tq) {
  return tq->size();
}

TORCH_LIBRARY_FRAGMENT(_TorchScriptTesting, m) {
  m.impl_abstract_pystub("torch.testing._internal.torchbind_impls");
  m.def(
      "takes_foo_cia(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
  m.def(
      "queue_pop(__torch__.torch.classes._TorchScriptTesting._TensorQueue foo) -> Tensor");
  m.def(
      "queue_push(__torch__.torch.classes._TorchScriptTesting._TensorQueue foo, Tensor x) -> ()");
  m.def(
      "queue_size(__torch__.torch.classes._TorchScriptTesting._TensorQueue foo) -> int");
}

TORCH_LIBRARY_IMPL(_TorchScriptTesting, CPU, m) {
  m.impl("takes_foo", takes_foo);
  m.impl("takes_foo_list_return", takes_foo_list_return);
  m.impl("takes_foo_tuple_return", takes_foo_tuple_return);
  m.impl("queue_push", queue_push);
  m.impl("queue_pop", queue_pop);
  m.impl("queue_size", queue_size);
  m.impl("takes_foo_tensor_return", takes_foo_tensor_return);
}

TORCH_LIBRARY_IMPL(_TorchScriptTesting, Meta, m) {
  m.impl("takes_foo", &takes_foo);
  m.impl("takes_foo_list_return", takes_foo_list_return);
  m.impl("takes_foo_tuple_return", takes_foo_tuple_return);
}

TORCH_LIBRARY_IMPL(_TorchScriptTesting, CompositeImplicitAutograd, m) {
  m.impl("takes_foo_cia", takes_foo);
}
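
// Registering takes_foo_cia under CompositeImplicitAutograd makes it a
// composite op: the same kernel serves every backend and autograd works
// through the ops it calls, so no per-backend or Meta kernel is needed.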

// Need to implement BackendSelect because these two operators don't have
// tensor inputs.
TORCH_LIBRARY_IMPL(_TorchScriptTesting, BackendSelect, m) {
  m.impl("queue_pop", queue_pop);
  m.impl("queue_size", queue_size);
}

} // namespace