Compare commits

...

2 Commits

Author SHA1 Message Date
ccf1e63e49 make fixup 2024-03-26 18:22:41 +00:00
5f122c83be Improve the version check for torch to fail less often 2024-03-26 18:19:41 +00:00

View File

@ -57,12 +57,13 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
package_version = temp_version
package_exists = True
else:
package_exists = False
package_version = "N/A"
except ImportError:
# If the package can't be imported, it's not available
package_exists = False
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_version = "N/A"
elif return_version:
# For packages other than "torch", don't attempt the fallback
# However, we only mark the package as not available if the version is explicitly requested
package_exists = False
logger.debug(f"Detected {pkg_name} version: {package_version}")
if return_version:
@ -173,6 +174,14 @@ _torch_version = "N/A"
_torch_available = False
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
_torch_available, _torch_version = _is_package_available("torch", return_version=True)
if _torch_available and _torch_version == "N/A":
# Here we have the situation where the import package for torch exists, but we can't
# find the distribution package containing its version data. In this case, we import it and ask it directly.
import torch
_torch_version = torch.__version__
if "+" in _torch_version:
torch_version = _torch_version.split("+")[0]
else:
logger.info("Disabling PyTorch because USE_TF is set")
_torch_available = False