mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Use our thread-safe getenv wrappers. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140200 Approved by: https://github.com/kwen2501, https://github.com/eqy
		
			
				
	
	
		
			99 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			99 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
#include <c10/util/Exception.h>
 | 
						|
#include <c10/util/env.h>
 | 
						|
#include <fmt/format.h>
 | 
						|
#include <cstdlib>
 | 
						|
#include <mutex>
 | 
						|
#include <shared_mutex>
 | 
						|
 | 
						|
namespace c10::utils {
 | 
						|
 | 
						|
// Returns the single reader/writer mutex that the environment helpers in
// this file lock before touching the process environment: set_env() takes it
// exclusively, get_env() takes it shared.
static std::shared_mutex& get_env_mutex() {
  // Function-local static: constructed on first call and the same object is
  // handed back on every subsequent call.
  static std::shared_mutex instance;
  return instance;
}
 | 
						|
 | 
						|
// Set an environment variable.
 | 
						|
void set_env(const char* name, const char* value, bool overwrite) {
 | 
						|
  std::lock_guard lk(get_env_mutex());
 | 
						|
#ifdef _MSC_VER
 | 
						|
#pragma warning(push)
 | 
						|
#pragma warning(disable : 4996)
 | 
						|
  if (!overwrite) {
 | 
						|
    // NOLINTNEXTLINE(concurrency-mt-unsafe)
 | 
						|
    if (std::getenv(name) != nullptr) {
 | 
						|
      return;
 | 
						|
    }
 | 
						|
  }
 | 
						|
  auto full_env_variable = fmt::format("{}={}", name, value);
 | 
						|
  // NOLINTNEXTLINE(concurrency-mt-unsafe)
 | 
						|
  auto err = putenv(full_env_variable.c_str());
 | 
						|
  TORCH_INTERNAL_ASSERT(
 | 
						|
      err == 0,
 | 
						|
      "putenv failed for environment \"",
 | 
						|
      name,
 | 
						|
      "\", the error is: ",
 | 
						|
      err);
 | 
						|
#pragma warning(pop)
 | 
						|
#else
 | 
						|
  // NOLINTNEXTLINE(concurrency-mt-unsafe)
 | 
						|
  auto err = setenv(name, value, static_cast<int>(overwrite));
 | 
						|
  TORCH_INTERNAL_ASSERT(
 | 
						|
      err == 0,
 | 
						|
      "setenv failed for environment \"",
 | 
						|
      name,
 | 
						|
      "\", the error is: ",
 | 
						|
      err);
 | 
						|
#endif
 | 
						|
  return;
 | 
						|
}
 | 
						|
 | 
						|
// Reads an environment variable and returns the content if it is set
 | 
						|
std::optional<std::string> get_env(const char* name) noexcept {
 | 
						|
  std::shared_lock lk(get_env_mutex());
 | 
						|
#ifdef _MSC_VER
 | 
						|
#pragma warning(push)
 | 
						|
#pragma warning(disable : 4996)
 | 
						|
#endif
 | 
						|
  // NOLINTNEXTLINE(concurrency-mt-unsafe)
 | 
						|
  auto envar = std::getenv(name);
 | 
						|
#ifdef _MSC_VER
 | 
						|
#pragma warning(pop)
 | 
						|
#endif
 | 
						|
  if (envar != nullptr) {
 | 
						|
    return std::string(envar);
 | 
						|
  }
 | 
						|
  return std::nullopt;
 | 
						|
}
 | 
						|
 | 
						|
// Checks an environment variable is set.
 | 
						|
bool has_env(const char* name) noexcept {
 | 
						|
  return get_env(name).has_value();
 | 
						|
}
 | 
						|
 | 
						|
// Reads an environment variable and returns
 | 
						|
// - optional<true>,              if set equal to "1"
 | 
						|
// - optional<false>,             if set equal to "0"
 | 
						|
// - nullopt,   otherwise
 | 
						|
//
 | 
						|
// NB:
 | 
						|
// Issues a warning if the value of the environment variable is not 0 or 1.
 | 
						|
std::optional<bool> check_env(const char* name) {
 | 
						|
  auto env_opt = get_env(name);
 | 
						|
  if (env_opt.has_value()) {
 | 
						|
    if (env_opt == "0") {
 | 
						|
      return false;
 | 
						|
    }
 | 
						|
    if (env_opt == "1") {
 | 
						|
      return true;
 | 
						|
    }
 | 
						|
    TORCH_WARN(
 | 
						|
        "Ignoring invalid value for boolean flag ",
 | 
						|
        name,
 | 
						|
        ": ",
 | 
						|
        *env_opt,
 | 
						|
        "valid values are 0 or 1.");
 | 
						|
  }
 | 
						|
  return std::nullopt;
 | 
						|
}
 | 
						|
} // namespace c10::utils
 |