Files
pytorch/test/cpp/api/parallel_benchmark.cpp
Nikita Shulga 3a66a1cb99 [clang-tidy] Exclude cppcoreguidelines-avoid-magic-numbers (#57841)
Summary:
Add cppcoreguidelines-avoid-magic-numbers exclusion to clang-tidy
Remove existing nolint warnings using the following script:
```
for file in `git ls-files | grep -v \.py`; do gsed '/^ *\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)/d' -i  $file; done
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57841

Reviewed By: samestep

Differential Revision: D28295045

Pulled By: malfet

fbshipit-source-id: 7c6e8d1213c9593f169ed3df6a916498f1a97163
2021-05-07 20:02:33 -07:00

89 lines
2.1 KiB
C++

#include <torch/torch.h>
#include <chrono>
#include <condition_variable>
#include <mutex>
// One-shot completion signal built on a mutex and a condition variable.
// post() sets the flag and wakes all waiters; wait() blocks until the flag is
// set. There is no reset, so once posted the baton stays posted.
class Baton {
 public:
  // Sets the flag under the mutex and notifies every thread blocked in wait().
  void post() {
    std::lock_guard<std::mutex> guard(mutex_);
    posted_ = true;
    ready_.notify_all();
  }

  // Blocks until post() has been called; returns right away if it already was.
  // The predicate form re-checks the flag after every wakeup, so spurious
  // wakeups are handled.
  void wait() {
    std::unique_lock<std::mutex> guard(mutex_);
    ready_.wait(guard, [this] { return posted_; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable ready_;
  bool posted_ = false;
};
// Benchmarks a chain of at::launch calls that carry no payload.
//
// The first step runs on the calling thread; each step that is not the last
// hands the next step to at::launch. The function blocks until the final step
// posts the baton, then prints the average elapsed time per step, in
// microseconds, to std::cout.
//
// @param numIters Number of steps in the chain. Values <= 0 are treated as a
//                 single step (the step that runs on the calling thread).
void AtLaunch_Base(int32_t numIters) {
  struct Helper {
    explicit Helper(int32_t lim) : limit_(lim) {}

    void operator()() {
      // `>=` rather than `==`: with a non-positive limit the counter would
      // never compare equal and the chain would keep re-launching itself.
      if (++val_ >= limit_) {
        done.post();
      } else {
        // `this` stays valid while steps run because the caller blocks on
        // done.wait() before the Helper goes out of scope.
        at::launch([this]() { (*this)(); });
      }
    }

    int val_{0};
    int limit_;
    Baton done;
  };

  Helper h(numIters);

  // steady_clock is monotonic; system_clock can jump (e.g. on a wall-clock
  // adjustment) and distort an interval measurement.
  const auto start = std::chrono::steady_clock::now();
  h();
  h.done.wait();
  const auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(
      std::chrono::steady_clock::now() - start);

  // At least one step always executes, so never divide by zero or a negative.
  const int32_t steps = numIters > 0 ? numIters : 1;
  std::cout << "NoData "
            << static_cast<double>(elapsed.count()) /
                   static_cast<double>(steps)
            << " usec/each\n";
}
// Benchmarks a chain of at::launch calls where every launched closure carries
// a std::vector<int32_t> payload.
//
// The first step runs on the calling thread; each step that is not the last
// moves its vector into a closure and hands that closure to at::launch. The
// function blocks until the final step posts the baton, then prints the
// average elapsed time per step, in microseconds, to std::cout.
//
// @param numIters Number of steps in the chain. Values <= 0 are treated as a
//                 single step (the step that runs on the calling thread).
// @param vecSize  Number of zero-initialized int32_t elements in the payload.
void AtLaunch_WithData(int32_t numIters, int32_t vecSize) {
  struct Helper {
    explicit Helper(int32_t lim) : limit_(lim) {}

    void operator()(std::vector<int32_t> v) {
      // `>=` rather than `==`: with a non-positive limit the counter would
      // never compare equal and the chain would keep re-launching itself.
      if (++val_ >= limit_) {
        done.post();
      } else {
        // The closure is not `mutable`, so the captured vector is const inside
        // it and the call below copies it into the by-value parameter. That
        // copy happens inside the timed region.
        // NOTE(review): confirm the per-hop copy is intended to be measured.
        at::launch([this, v = std::move(v)]() { (*this)(v); });
      }
    }

    int val_{0};
    int limit_;
    Baton done;
  };

  Helper h(numIters);
  std::vector<int32_t> v(vecSize, 0);

  // steady_clock is monotonic; system_clock can jump (e.g. on a wall-clock
  // adjustment) and distort an interval measurement.
  const auto start = std::chrono::steady_clock::now();
  h(v);
  h.done.wait();
  const auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(
      std::chrono::steady_clock::now() - start);

  // At least one step always executes, so never divide by zero or a negative.
  const int32_t steps = numIters > 0 ? numIters : 1;
  std::cout << "WithData(" << vecSize << "): "
            << static_cast<double>(elapsed.count()) /
                   static_cast<double>(steps)
            << " usec/each\n";
}
// Entry point: runs the no-payload benchmark once, then the payload benchmark
// for vector sizes 0, 4 and 256, all with one million iterations. The
// command-line arguments are not used.
int main(int argc, char** argv) {
  const int32_t iterations = 1000000;
  const int32_t payloadSizes[] = {0, 4, 256};

  AtLaunch_Base(iterations);
  for (const int32_t payloadSize : payloadSizes) {
    AtLaunch_WithData(iterations, payloadSize);
  }
  return 0;
}