Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 20:34:54 +08:00)

	Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os
def get_compiled_files_list():
    import json
    with open("build/compile_commands.json") as f:
        data = json.load(f)
    files = [os.path.relpath(node['file']) for node in data]
    for idx, fname in enumerate(files):
        if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
            files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
    return files
def run_clang_tidy(fname):
    check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname, "-s"])
    changes = check_output(["git", "ls-files", "-m"])
    if len(changes) == 0:
        return
    check_call(["git", "commit", "--all", "-m", f"NOLINT stubs for {fname}"])
def main():
    git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
    compiled_files = get_compiled_files_list()
    for idx, fname in enumerate(git_files):
        if fname not in compiled_files:
            continue
        if fname.startswith("caffe2/contrib/aten/"):
            continue
        print(f"[{idx}/{len(git_files)}] Processing {fname}")
        run_clang_tidy(fname)
if __name__ == "__main__":
    main()
```
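
In practice, the stubs the script commits look like this in the file shown below: a `NOLINTNEXTLINE` comment naming the suppressed check, placed directly above the flagged line (this fragment is copied verbatim from the file):
```
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModuleDictTest, ConstructsFromList) {
```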
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179

344 lines · 12 KiB · C++

```
#include <gtest/gtest.h>
#include <torch/torch.h>
#include <algorithm>
#include <memory>
#include <vector>

#include <test/cpp/api/support.h>

using namespace torch::nn;
using namespace torch::test;

struct ModuleDictTest : torch::test::SeedingFixture {};

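// These tests exercise the torch::nn::ModuleDict container: construction,
// update/pop/clear, key and value access, reference semantics, cloning,
// and pretty-printing.
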
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModuleDictTest, ConstructsFromList) {
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list = {
    {"module_1", std::make_shared<M>(1)},
    {"module_2", std::make_shared<M>(2)},
    {"module_3", std::make_shared<M>(3)}
  };
  ModuleDict dict(list);
  ASSERT_EQ(dict->size(), 3);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModuleDictTest, ConstructsFromordereddict) {
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
    {"module_1", std::make_shared<M>(1)},
    {"module_2", std::make_shared<M>(2)},
    {"module_3", std::make_shared<M>(3)},
  };
  ModuleDict dict(ordereddict);
  ASSERT_EQ(dict->size(), 3);
}

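// update() accepts a vector of pairs, an OrderedDict, or another ModuleDict;
// pop() removes a module by key and throws if the key is not defined.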
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModuleDictTest, UpdatePopClearContains) {
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  ModuleDict dict;
  ASSERT_TRUE(dict->empty());
  // Update from a list
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list1 = {
    {"module_1", std::make_shared<M>(1)}
  };
  dict->update(list1);
  ASSERT_EQ(dict->size(), 1);
  ASSERT_TRUE(dict->contains("module_1"));
  // Update from an OrderedDict
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
    {"module_2", std::make_shared<M>(2)}
  };
  dict->update(ordereddict);
  ASSERT_EQ(dict->size(), 2);
  ASSERT_TRUE(dict->contains("module_2"));
  // Update from another ModuleDict
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list2 = {
    {"module_3", std::make_shared<M>(3)}
  };
  ModuleDict updatedict(list2);
  dict->update(*updatedict);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_TRUE(dict->contains("module_3"));
  // Pop an existing key
  dict->pop("module_1");
  ASSERT_EQ(dict->size(), 2);
  // Popping a nonexistent key throws
  ASSERT_THROWS_WITH(dict->pop("module_4"), " 'module_4' is not defined");
  // Clear
  dict->clear();
  ASSERT_EQ(dict->size(), 0);
}

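// Updating with an already-present key replaces the stored module in place.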
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModuleDictTest, UpdateExist) {
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list1 = {
    {"module_1", std::make_shared<M>(1)},
    {"module_2", std::make_shared<M>(2)}
  };
  ModuleDict dict(list1);
  ASSERT_EQ(dict->at<M>("module_2").value, 2);
  // Update by list
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list2 = {
    {"module_2", std::make_shared<M>(0)},
    {"module_3", std::make_shared<M>(3)}
  };
  dict->update(list2);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_EQ(dict->at<M>("module_2").value, 0);
  // Update by ordereddict
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
    {"module_3", std::make_shared<M>(0)},
    {"module_4", std::make_shared<M>(4)}
  };
  dict->update(ordereddict);
  ASSERT_EQ(dict->size(), 4);
  ASSERT_EQ(dict->at<M>("module_3").value, 0);
  // Update by ModuleDict
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list3 = {
    {"module_4", std::make_shared<M>(0)},
    {"module_1", std::make_shared<M>(0)}
  };
  ModuleDict dict2(list3);
  dict->update(*dict2);
  ASSERT_EQ(dict->size(), 4);
  ASSERT_EQ(dict->at<M>("module_1").value, 0);
  ASSERT_EQ(dict->at<M>("module_4").value, 0);
}

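// keys() preserves insertion order, and operator[] throws for unknown keys.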
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModuleDictTest, Keys) {
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"linear", Linear(10, 3).ptr()},
    {"conv", Conv2d(1, 2, 3).ptr()},
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"dropout", Dropout(0.5).ptr()},
  };
  ModuleDict dict(ordereddict);
  const auto& keys = dict->keys();
  std::vector<std::string> expected{"linear", "conv", "dropout"};
  ASSERT_EQ(keys, expected);
  ASSERT_THROWS_WITH(dict["batch"], " 'batch' is not defined");

  ASSERT_TRUE(dict["linear"]->as<Linear>());
  ASSERT_TRUE(dict["conv"]->as<Conv2d>());
  ASSERT_TRUE(dict["dropout"]->as<Dropout>());
}

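// values() returns the stored modules in insertion order, aliasing the
// originals rather than copying them.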
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModuleDictTest, Values) {
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
    {"module_1", std::make_shared<M>(1)},
    {"module_2", std::make_shared<M>(2)},
  };
  ModuleDict dict(ordereddict);
  const auto& values = dict->values();
  const auto& expected = ordereddict.values();
  ASSERT_EQ(values, expected);
  ASSERT_TRUE(std::equal(
      dict->begin(),
      dict->end(),
      ordereddict.begin(),
      [](const auto& lhs, const auto& rhs) {
        return lhs.value().get() == rhs.value().get();
      }));
}

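// Smoke test: a ModuleDict should be able to hold any standard module type.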
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModuleDictTest, SanityCheckForHoldingStandardModules) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"linear", Linear(10, 3).ptr()},
    {"conv", Conv2d(1, 2, 3).ptr()},
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"dropout", Dropout(0.5).ptr()},
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"batch", BatchNorm2d(5).ptr()},
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"embedding", Embedding(4, 10).ptr()},
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"lstm", LSTM(4, 5).ptr()}
  };
  ModuleDict dict(ordereddict);
}

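// Two ModuleDicts built from the same OrderedDict share the underlying
// module instances (reference semantics) rather than copying them.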
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModuleDictTest, HasReferenceSemantics) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
    {"linear1", Linear(2, 3).ptr()},
    {"linear2", Linear(3, 4).ptr()},
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"linear3", Linear(4, 5).ptr()},
  };
  ModuleDict first(ordereddict);
  ModuleDict second(ordereddict);

  ASSERT_EQ(first->size(), second->size());
  ASSERT_TRUE(std::equal(
      first->begin(),
      first->end(),
      second->begin(),
      [](const auto& lhs, const auto& rhs) {
        return lhs.value().get() == rhs.value().get();
      }));
}

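// Shared helper for the IsCloneable tests: clone() must produce a deep copy
// whose keys and module types match the original, but whose modules and
// parameters are distinct objects.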
void iscloneable_helper(torch::Device device) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
    {"linear", Linear(2, 3).ptr()},
    {"relu", Functional(torch::relu).ptr()},
    {"batch", BatchNorm1d(3).ptr()},
  };
  ModuleDict dict(ordereddict);
  dict->to(device);
  ModuleDict clone =
      std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(device));
  ASSERT_EQ(dict->size(), clone->size());

  for (auto it = dict->begin(), it_c = clone->begin(); it != dict->end();
       ++it, ++it_c) {
    // The keys should be the same.
    ASSERT_EQ(it->key(), it_c->key());
    // The modules should be the same kind (type).
    ASSERT_EQ(it->value()->name(), it_c->value()->name());
    // But not pointer-equal (distinct objects).
    ASSERT_NE(it->value(), it_c->value());
  }

  // Verify that the clone is deep, i.e. parameters of modules are cloned too.
  torch::NoGradGuard no_grad;

  auto params1 = dict->named_parameters();
  auto params2 = clone->named_parameters();
  ASSERT_EQ(params1.size(), params2.size());
  for (auto& param : params1) {
    ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
    ASSERT_EQ(param->device(), params2[param.key()].device());
    ASSERT_TRUE(param->allclose(params2[param.key()]));
    param->add_(2);
  }
  for (auto& param : params1) {
    ASSERT_FALSE(param->allclose(params2[param.key()]));
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModuleDictTest, IsCloneable) {
  iscloneable_helper(torch::kCPU);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModuleDictTest, IsCloneable_CUDA) {
  iscloneable_helper({torch::kCUDA, 0});
}

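// children() reflects insertion order; updating an existing key keeps its
// position, while new keys are appended at the end.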
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModuleDictTest, RegistersElementsAsSubmodules) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict1 = {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"linear", Linear(10, 3).ptr()},
    {"conv", Conv2d(1, 2, 3).ptr()},
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"test", Dropout(0.5).ptr()},
  };
  ModuleDict dict(ordereddict1);

  auto modules = dict->children();
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  ASSERT_TRUE(modules[2]->as<Dropout>());

  // Update an existing key ("test") and add a new one ("lstm")
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict2 = {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"lstm", LSTM(4, 5).ptr()},
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"test", BatchNorm2d(5).ptr()}
  };
  dict->update(ordereddict2);

  modules = dict->children();
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  // The replaced module keeps its original position...
  ASSERT_TRUE(modules[2]->as<BatchNorm2d>());
  // ...while the new one is appended.
  ASSERT_TRUE(modules[3]->as<LSTM>());
}

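// clone(device) must place all parameters and buffers on the target device.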
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModuleDictTest, CloneToDevice_CUDA) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
    {"linear", Linear(2, 3).ptr()},
    {"relu", Functional(torch::relu).ptr()},
    {"batch", BatchNorm1d(3).ptr()},
  };
  ModuleDict dict(ordereddict);
  torch::Device device(torch::kCUDA, 0);
  ModuleDict clone =
      std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(device));
  for (const auto& p : clone->parameters()) {
    ASSERT_EQ(p.device(), device);
  }
  for (const auto& b : clone->buffers()) {
    ASSERT_EQ(b.device(), device);
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModuleDictTest, PrettyPrintModuleDict) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"linear", Linear(10, 3).ptr()},
    {"conv", Conv2d(1, 2, 3).ptr()},
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"dropout", Dropout(0.5).ptr()},
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"batch", BatchNorm2d(5).ptr()},
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"embedding", Embedding(4, 10).ptr()},
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    {"lstm", LSTM(4, 5).ptr()}
  };
  ModuleDict dict(ordereddict);

  ASSERT_EQ(
      c10::str(dict),
      "torch::nn::ModuleDict(\n"
      "  (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      "  (conv): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      "  (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
      "  (batch): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      "  (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      "  (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")");
}
```