Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[BE] Enable ruff's UP rules and autoformat optim/ (#105426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105426 Approved by: https://github.com/malfet, https://github.com/albanD, https://github.com/aaronenyeshi, https://github.com/janeyx99
Commit: 3721fa5612
Parent: be03a56955
Committed by: PyTorch MergeBot
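The diff below is almost entirely mechanical rewrites produced by pyupgrade-style fixes: `str.format()` calls become f-strings, plus a few related cleanups. As a rough, hedged sketch of how such a pass is typically configured and run — the exact configuration and command used for this PR are not shown in this excerpt and are assumptions:

    # assumed ruff configuration (e.g. in pyproject.toml)
    [tool.ruff]
    select = ["UP"]    # enable pyupgrade ("UP") rules, e.g. UP032: prefer f-strings over str.format()

    # assumed invocation to autoformat a subtree
    ruff check --select UP --fix torch/optim/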
@@ -83,7 +83,7 @@ def test_biject_to(constraint_fn, args, is_cuda):
t = biject_to(constraint)
except NotImplementedError:
pytest.skip('`biject_to` not implemented.')
- assert t.bijective, "biject_to({}) is not bijective".format(constraint)
+ assert t.bijective, f"biject_to({constraint}) is not bijective"
if constraint_fn is constraints.corr_cholesky:
# (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim)
x = torch.randn(6, 6, dtype=torch.double)
@@ -93,12 +93,12 @@ def test_biject_to(constraint_fn, args, is_cuda):
x = x.cuda()
y = t(x)
assert constraint.check(y).all(), '\n'.join([
- "Failed to biject_to({})".format(constraint),
- "x = {}".format(x),
- "biject_to(...)(x) = {}".format(y),
+ f"Failed to biject_to({constraint})",
+ f"x = {x}",
+ f"biject_to(...)(x) = {y}",
])
x2 = t.inv(y)
- assert torch.allclose(x, x2), "Error in biject_to({}) inverse".format(constraint)
+ assert torch.allclose(x, x2), f"Error in biject_to({constraint}) inverse"

j = t.log_abs_det_jacobian(x, y)
assert j.shape == x.shape[:x.dim() - t.domain.event_dim]
@@ -119,10 +119,10 @@ def test_transform_to(constraint_fn, args, is_cuda):
if is_cuda:
x = x.cuda()
y = t(x)
- assert constraint.check(y).all(), "Failed to transform_to({})".format(constraint)
+ assert constraint.check(y).all(), f"Failed to transform_to({constraint})"
x2 = t.inv(y)
y2 = t(x2)
- assert torch.allclose(y, y2), "Error in transform_to({}) pseudoinverse".format(constraint)
+ assert torch.allclose(y, y2), f"Error in transform_to({constraint}) pseudoinverse"


if __name__ == "__main__":
@@ -862,7 +862,7 @@ class TestDistributions(DistributionsTestCase):
bins = samples.reshape((num_bins, samples_per_bin)).mean(axis=1)
stddev = samples_per_bin ** -0.5
threshold = stddev * scipy.special.erfinv(1 - 2 * failure_rate / num_bins)
- message = '{}.sample() is biased:\n{}'.format(message, bins)
+ message = f'{message}.sample() is biased:\n{bins}'
for bias in bins:
self.assertLess(-threshold, bias, message)
self.assertLess(bias, threshold, message)
@@ -971,7 +971,7 @@ class TestDistributions(DistributionsTestCase):
if isinstance(Dist, type) and issubclass(Dist, Distribution) \
and Dist is not Distribution and Dist is not ExponentialFamily:
self.assertIn(Dist, distributions_with_examples,
- "Please add {} to the EXAMPLES list in test_distributions.py".format(Dist.__name__))
+ f"Please add {Dist.__name__} to the EXAMPLES list in test_distributions.py")

def test_support_attributes(self):
for Dist, params in EXAMPLES:
@@ -1120,7 +1120,7 @@ class TestDistributions(DistributionsTestCase):
for prob in [0.01, 0.18, 0.8]:
self._check_sampler_discrete(Geometric(prob),
scipy.stats.geom(p=prob, loc=-1),
- 'Geometric(prob={})'.format(prob))
+ f'Geometric(prob={prob})')

def test_binomial(self):
p = torch.arange(0.05, 1, 0.1).requires_grad_()
@@ -1136,7 +1136,7 @@ class TestDistributions(DistributionsTestCase):
for count in [2, 10, 100, 500]:
self._check_sampler_discrete(Binomial(total_count=count, probs=prob),
scipy.stats.binom(count, prob),
- 'Binomial(total_count={}, probs={})'.format(count, prob))
+ f'Binomial(total_count={count}, probs={prob})')

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_binomial_log_prob_and_entropy(self):
@@ -1431,7 +1431,7 @@ class TestDistributions(DistributionsTestCase):
for rate in [0.1, 1.0, 5.0]:
self._check_sampler_discrete(Poisson(rate),
scipy.stats.poisson(rate),
- 'Poisson(lambda={})'.format(rate),
+ f'Poisson(lambda={rate})',
failure_rate=1e-3)

@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@@ -1441,7 +1441,7 @@ class TestDistributions(DistributionsTestCase):
for rate in [0.12, 0.9, 4.0]:
self._check_sampler_discrete(Poisson(torch.tensor([rate]).cuda()),
scipy.stats.poisson(rate),
- 'Poisson(lambda={}, cuda)'.format(rate),
+ f'Poisson(lambda={rate}, cuda)',
failure_rate=1e-3)

def test_relaxed_bernoulli(self):
@@ -1476,7 +1476,7 @@ class TestDistributions(DistributionsTestCase):
for probs, temp in product([0.1, 0.2, 0.8], [0.1, 1.0, 10.0]):
self._check_sampler_discrete(Rounded(RelaxedBernoulli(temp, probs)),
scipy.stats.bernoulli(probs),
- 'Rounded(RelaxedBernoulli(temp={}, probs={}))'.format(temp, probs),
+ f'Rounded(RelaxedBernoulli(temp={temp}, probs={probs}))',
failure_rate=1e-3)

for probs in [0.001, 0.2, 0.999]:
@@ -1534,7 +1534,7 @@ class TestDistributions(DistributionsTestCase):
for probs, temp in product([torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])], [0.1, 1.0, 10.0]):
self._check_sampler_discrete(ArgMax(RelaxedOneHotCategorical(temp, probs)),
ScipyCategorical(scipy.stats.multinomial(1, probs)),
- 'Rounded(RelaxedOneHotCategorical(temp={}, probs={}))'.format(temp, probs),
+ f'Rounded(RelaxedOneHotCategorical(temp={temp}, probs={probs}))',
failure_rate=1e-3)

for probs in [torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])]:
@@ -1588,7 +1588,7 @@ class TestDistributions(DistributionsTestCase):
for concentration in [0.03, 0.3, 1.0, 10.0, 100.0]:
self._check_sampler_sampler(VonMises(loc, concentration),
scipy.stats.vonmises(loc=loc, kappa=concentration),
- "VonMises(loc={}, concentration={})".format(loc, concentration),
+ f"VonMises(loc={loc}, concentration={concentration})",
num_samples=int(1e5), circular=True)

def test_vonmises_logprob(self):
@@ -1694,7 +1694,7 @@ class TestDistributions(DistributionsTestCase):
for std in [0.1, 1.0, 10.0]:
self._check_sampler_sampler(HalfNormal(std),
scipy.stats.halfnorm(scale=std),
- 'HalfNormal(scale={})'.format(std))
+ f'HalfNormal(scale={std})')

def test_lognormal(self):
mean = torch.randn(5, 5, requires_grad=True)
@@ -1746,7 +1746,7 @@ class TestDistributions(DistributionsTestCase):
for mean, std in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
self._check_sampler_sampler(LogNormal(mean, std),
scipy.stats.lognorm(scale=math.exp(mean), s=std),
- 'LogNormal(loc={}, scale={})'.format(mean, std))
+ f'LogNormal(loc={mean}, scale={std})')

def test_logisticnormal(self):
set_rng_seed(1) # see Note [Randomized statistical tests]
@@ -1814,7 +1814,7 @@ class TestDistributions(DistributionsTestCase):
std_th = torch.tensor(np.sqrt(np.diag(cov)))
self._check_sampler_sampler(
LogisticNormal(mean_th, std_th), ref_dist,
- 'LogisticNormal(loc={}, scale={})'.format(mean_th, std_th),
+ f'LogisticNormal(loc={mean_th}, scale={std_th})',
multivariate=True)

def test_mixture_same_family_shape(self):
@@ -1958,7 +1958,7 @@ class TestDistributions(DistributionsTestCase):
for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
self._check_sampler_sampler(Normal(loc, scale),
scipy.stats.norm(loc=loc, scale=scale),
- 'Normal(mean={}, std={})'.format(loc, scale))
+ f'Normal(mean={loc}, std={scale})')

def test_lowrank_multivariate_normal_shape(self):
mean = torch.randn(5, 3, requires_grad=True)
@@ -2191,15 +2191,15 @@ class TestDistributions(DistributionsTestCase):

self._check_sampler_sampler(MultivariateNormal(mean, cov),
scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
- 'MultivariateNormal(loc={}, cov={})'.format(mean, cov),
+ f'MultivariateNormal(loc={mean}, cov={cov})',
multivariate=True)
self._check_sampler_sampler(MultivariateNormal(mean, precision_matrix=prec),
scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
- 'MultivariateNormal(loc={}, atol={})'.format(mean, prec),
+ f'MultivariateNormal(loc={mean}, atol={prec})',
multivariate=True)
self._check_sampler_sampler(MultivariateNormal(mean, scale_tril=scale_tril),
scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
- 'MultivariateNormal(loc={}, scale_tril={})'.format(mean, scale_tril),
+ f'MultivariateNormal(loc={mean}, scale_tril={scale_tril})',
multivariate=True)

def test_multivariate_normal_properties(self):
@@ -2352,15 +2352,15 @@ class TestDistributions(DistributionsTestCase):

self._check_sampler_sampler(Wishart(df, cov),
ref_dist,
- 'Wishart(df={}, covariance_matrix={})'.format(df, cov),
+ f'Wishart(df={df}, covariance_matrix={cov})',
multivariate=True)
self._check_sampler_sampler(Wishart(df, precision_matrix=prec),
ref_dist,
- 'Wishart(df={}, precision_matrix={})'.format(df, prec),
+ f'Wishart(df={df}, precision_matrix={prec})',
multivariate=True)
self._check_sampler_sampler(Wishart(df, scale_tril=scale_tril),
ref_dist,
- 'Wishart(df={}, scale_tril={})'.format(df, scale_tril),
+ f'Wishart(df={df}, scale_tril={scale_tril})',
multivariate=True)

def test_wishart_properties(self):
@@ -2431,7 +2431,7 @@ class TestDistributions(DistributionsTestCase):
for rate in [1e-5, 1.0, 10.]:
self._check_sampler_sampler(Exponential(rate),
scipy.stats.expon(scale=1. / rate),
- 'Exponential(rate={})'.format(rate))
+ f'Exponential(rate={rate})')

def test_laplace(self):
loc = torch.randn(5, 5, requires_grad=True)
@@ -2482,7 +2482,7 @@ class TestDistributions(DistributionsTestCase):
for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
self._check_sampler_sampler(Laplace(loc, scale),
scipy.stats.laplace(loc=loc, scale=scale),
- 'Laplace(loc={}, scale={})'.format(loc, scale))
+ f'Laplace(loc={loc}, scale={scale})')

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_gamma_shape(self):
@@ -2533,7 +2533,7 @@ class TestDistributions(DistributionsTestCase):
for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
self._check_sampler_sampler(Gamma(alpha, beta),
scipy.stats.gamma(alpha, scale=1.0 / beta),
- 'Gamma(concentration={}, rate={})'.format(alpha, beta))
+ f'Gamma(concentration={alpha}, rate={beta})')

@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
@@ -2543,7 +2543,7 @@ class TestDistributions(DistributionsTestCase):
a, b = torch.tensor([alpha]).cuda(), torch.tensor([beta]).cuda()
self._check_sampler_sampler(Gamma(a, b),
scipy.stats.gamma(alpha, scale=1.0 / beta),
- 'Gamma(alpha={}, beta={})'.format(alpha, beta),
+ f'Gamma(alpha={alpha}, beta={beta})',
failure_rate=1e-4)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@@ -2575,7 +2575,7 @@ class TestDistributions(DistributionsTestCase):
for scale, alpha in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
self._check_sampler_sampler(Pareto(scale, alpha),
scipy.stats.pareto(alpha, scale=scale),
- 'Pareto(scale={}, alpha={})'.format(scale, alpha))
+ f'Pareto(scale={scale}, alpha={alpha})')

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_gumbel(self):
@@ -2616,7 +2616,7 @@ class TestDistributions(DistributionsTestCase):
for loc, scale in product([-5.0, -1.0, -0.1, 0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
self._check_sampler_sampler(Gumbel(loc, scale),
scipy.stats.gumbel_r(loc=loc, scale=scale),
- 'Gumbel(loc={}, scale={})'.format(loc, scale))
+ f'Gumbel(loc={loc}, scale={scale})')

def test_kumaraswamy_shape(self):
concentration1 = torch.randn(2, 3).abs().requires_grad_()
@@ -2646,13 +2646,13 @@ class TestDistributions(DistributionsTestCase):
error = (expected - actual).abs()
max_error = max(error[error == error])
self.assertLess(max_error, 0.01,
- "Kumaraswamy example {}/{}, incorrect .mean".format(i + 1, len(cases)))
+ f"Kumaraswamy example {i + 1}/{len(cases)}, incorrect .mean")
expected = samples.var(0)
actual = m.variance
error = (expected - actual).abs()
max_error = max(error[error == error])
self.assertLess(max_error, 0.01,
- "Kumaraswamy example {}/{}, incorrect .variance".format(i + 1, len(cases)))
+ f"Kumaraswamy example {i + 1}/{len(cases)}, incorrect .variance")

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_fishersnedecor(self):
@@ -2683,7 +2683,7 @@ class TestDistributions(DistributionsTestCase):
for df1, df2 in product([0.1, 0.5, 1.0, 5.0, 10.0], [0.1, 0.5, 1.0, 5.0, 10.0]):
self._check_sampler_sampler(FisherSnedecor(df1, df2),
scipy.stats.f(df1, df2),
- 'FisherSnedecor(loc={}, scale={})'.format(df1, df2))
+ f'FisherSnedecor(loc={df1}, scale={df2})')

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_chi2_shape(self):
@@ -2710,7 +2710,7 @@ class TestDistributions(DistributionsTestCase):
for df in [0.1, 1.0, 5.0]:
self._check_sampler_sampler(Chi2(df),
scipy.stats.chi2(df),
- 'Chi2(df={})'.format(df))
+ f'Chi2(df={df})')

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_studentT(self):
@@ -2740,7 +2740,7 @@ class TestDistributions(DistributionsTestCase):
for df, loc, scale in product([0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
self._check_sampler_sampler(StudentT(df=df, loc=loc, scale=scale),
scipy.stats.t(df=df, loc=loc, scale=scale),
- 'StudentT(df={}, loc={}, scale={})'.format(df, loc, scale))
+ f'StudentT(df={df}, loc={loc}, scale={scale})')

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_studentT_log_prob(self):
@@ -2793,7 +2793,7 @@ class TestDistributions(DistributionsTestCase):
alpha = torch.exp(torch.randn(3))
self._check_sampler_sampler(Dirichlet(alpha),
scipy.stats.dirichlet(alpha.numpy()),
- 'Dirichlet(alpha={})'.format(list(alpha)),
+ f'Dirichlet(alpha={list(alpha)})',
multivariate=True)

def test_dirichlet_mode(self):
@@ -2837,11 +2837,11 @@ class TestDistributions(DistributionsTestCase):
for con1, con0 in product([0.1, 1.0, 10.0], [0.1, 1.0, 10.0]):
self._check_sampler_sampler(Beta(con1, con0),
scipy.stats.beta(con1, con0),
- 'Beta(alpha={}, beta={})'.format(con1, con0))
+ f'Beta(alpha={con1}, beta={con0})')
# Check that small alphas do not cause NANs.
for Tensor in [torch.FloatTensor, torch.DoubleTensor]:
x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
- self.assertTrue(np.isfinite(x) and x > 0, 'Invalid Beta.sample(): {}'.format(x))
+ self.assertTrue(np.isfinite(x) and x > 0, f'Invalid Beta.sample(): {x}')

def test_beta_underflow(self):
# For low values of (alpha, beta), the gamma samples can underflow
@@ -2997,10 +2997,10 @@ class TestDistributions(DistributionsTestCase):
continue
rel_error = torch.abs(actual - samples) / (1e-10 + torch.abs(samples))
self.assertLess(rel_error.max(), 1e-4, msg='\n'.join([
- '{} example {}/{}, icdf(cdf(x)) != x'.format(Dist.__name__, i + 1, len(params)),
- 'x = {}'.format(samples),
- 'cdf(x) = {}'.format(cdf),
- 'icdf(cdf(x)) = {}'.format(actual),
+ f'{Dist.__name__} example {i + 1}/{len(params)}, icdf(cdf(x)) != x',
+ f'x = {samples}',
+ f'cdf(x) = {cdf}',
+ f'icdf(cdf(x)) = {actual}',
]))

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@@ -3029,11 +3029,11 @@ class TestDistributions(DistributionsTestCase):
continue
cdfs_derivative = grad(cdfs.sum(), [samples])[0] # this should not be wrapped in torch.abs()
self.assertEqual(cdfs_derivative, pdfs, msg='\n'.join([
- '{} example {}/{}, d(cdf)/dx != pdf(x)'.format(Dist.__name__, i + 1, len(params)),
- 'x = {}'.format(samples),
- 'cdf = {}'.format(cdfs),
- 'pdf = {}'.format(pdfs),
- 'grad(cdf) = {}'.format(cdfs_derivative),
+ f'{Dist.__name__} example {i + 1}/{len(params)}, d(cdf)/dx != pdf(x)',
+ f'x = {samples}',
+ f'cdf = {cdfs}',
+ f'pdf = {pdfs}',
+ f'grad(cdf) = {cdfs_derivative}',
]))

def test_valid_parameter_broadcasting(self):
@@ -3144,13 +3144,13 @@ class TestDistributions(DistributionsTestCase):
for dist, expected_size in valid_examples:
actual_size = dist.sample().size()
self.assertEqual(actual_size, expected_size,
- msg='{} actual size: {} != expected size: {}'.format(dist, actual_size, expected_size))
+ msg=f'{dist} actual size: {actual_size} != expected size: {expected_size}')

sample_shape = torch.Size((2,))
expected_size = sample_shape + expected_size
actual_size = dist.sample(sample_shape).size()
self.assertEqual(actual_size, expected_size,
- msg='{} actual size: {} != expected size: {}'.format(dist, actual_size, expected_size))
+ msg=f'{dist} actual size: {actual_size} != expected size: {expected_size}')

def test_invalid_parameter_broadcasting(self):
# invalid broadcasting cases; should throw error
@@ -3303,13 +3303,13 @@ class TestRsample(DistributionsTestCase):
expected_grad = -cdf_alpha / cdf_x
rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
self.assertLess(np.max(rel_error), 0.0005, '\n'.join([
- 'Bad gradient dx/alpha for x ~ Gamma({}, 1)'.format(alpha),
- 'x {}'.format(x),
- 'expected {}'.format(expected_grad),
- 'actual {}'.format(actual_grad),
- 'rel error {}'.format(rel_error),
- 'max error {}'.format(rel_error.max()),
- 'at alpha={}, x={}'.format(alpha, x[rel_error.argmax()]),
+ f'Bad gradient dx/alpha for x ~ Gamma({alpha}, 1)',
+ f'x {x}',
+ f'expected {expected_grad}',
+ f'actual {actual_grad}',
+ f'rel error {rel_error}',
+ f'max error {rel_error.max()}',
+ f'at alpha={alpha}, x={x[rel_error.argmax()]}',
]))

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@@ -3331,12 +3331,12 @@ class TestRsample(DistributionsTestCase):
expected_grad = -cdf_df / cdf_x
rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
self.assertLess(np.max(rel_error), 0.001, '\n'.join([
- 'Bad gradient dx/ddf for x ~ Chi2({})'.format(df),
- 'x {}'.format(x),
- 'expected {}'.format(expected_grad),
- 'actual {}'.format(actual_grad),
- 'rel error {}'.format(rel_error),
- 'max error {}'.format(rel_error.max()),
+ f'Bad gradient dx/ddf for x ~ Chi2({df})',
+ f'x {x}',
+ f'expected {expected_grad}',
+ f'actual {actual_grad}',
+ f'rel error {rel_error}',
+ f'max error {rel_error.max()}',
]))

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@@ -3361,13 +3361,13 @@ class TestRsample(DistributionsTestCase):
expected_grad = -cdf_alpha / cdf_x
rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
self.assertLess(np.max(rel_error), 0.001, '\n'.join([
- 'Bad gradient dx[0]/dalpha[0] for Dirichlet([{}, {}, {}])'.format(a0, a1, a2),
- 'x {}'.format(x),
- 'expected {}'.format(expected_grad),
- 'actual {}'.format(actual_grad),
- 'rel error {}'.format(rel_error),
- 'max error {}'.format(rel_error.max()),
- 'at x={}'.format(x[rel_error.argmax()]),
+ f'Bad gradient dx[0]/dalpha[0] for Dirichlet([{a0}, {a1}, {a2}])',
+ f'x {x}',
+ f'expected {expected_grad}',
+ f'actual {actual_grad}',
+ f'rel error {rel_error}',
+ f'max error {rel_error.max()}',
+ f'at x={x[rel_error.argmax()]}',
]))

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@@ -3391,13 +3391,13 @@ class TestRsample(DistributionsTestCase):
expected_grad = -cdf_alpha / cdf_x
rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
self.assertLess(np.max(rel_error), 0.005, '\n'.join([
- 'Bad gradient dx/dcon1 for x ~ Beta({}, {})'.format(con1, con0),
- 'x {}'.format(x),
- 'expected {}'.format(expected_grad),
- 'actual {}'.format(actual_grad),
- 'rel error {}'.format(rel_error),
- 'max error {}'.format(rel_error.max()),
- 'at x = {}'.format(x[rel_error.argmax()]),
+ f'Bad gradient dx/dcon1 for x ~ Beta({con1}, {con0})',
+ f'x {x}',
+ f'expected {expected_grad}',
+ f'actual {actual_grad}',
+ f'rel error {rel_error}',
+ f'max error {rel_error.max()}',
+ f'at x = {x[rel_error.argmax()]}',
]))

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@@ -3421,13 +3421,13 @@ class TestRsample(DistributionsTestCase):
expected_grad = -cdf_beta / cdf_x
rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
self.assertLess(np.max(rel_error), 0.005, '\n'.join([
- 'Bad gradient dx/dcon0 for x ~ Beta({}, {})'.format(con1, con0),
- 'x {}'.format(x),
- 'expected {}'.format(expected_grad),
- 'actual {}'.format(actual_grad),
- 'rel error {}'.format(rel_error),
- 'max error {}'.format(rel_error.max()),
- 'at x = {!r}'.format(x[rel_error.argmax()]),
+ f'Bad gradient dx/dcon0 for x ~ Beta({con1}, {con0})',
+ f'x {x}',
+ f'expected {expected_grad}',
+ f'actual {actual_grad}',
+ f'rel error {rel_error}',
+ f'max error {rel_error.max()}',
+ f'at x = {x[rel_error.argmax()]!r}',
]))

def test_dirichlet_multivariate(self):
@@ -3485,8 +3485,8 @@ class TestRsample(DistributionsTestCase):
# expression in terms of log_prob rather than the less numerically stable log_prob.exp().
error = dlogp_da + (dlogp_dx * v).sum(-1) + div_v
self.assertLess(torch.abs(error).max(), 0.005, '\n'.join([
- 'Dirichlet([{}, {}, {}]) gradient violates continuity equation:'.format(a1, a2, a3),
- 'error = {}'.format(error),
+ f'Dirichlet([{a1}, {a2}, {a3}]) gradient violates continuity equation:',
+ f'error = {error}',
]))
@@ -4147,9 +4147,9 @@ class TestKL(DistributionsTestCase):
if error[error == error].max() < self.precision:
break
self.assertLess(error[error == error].max(), self.precision, '\n'.join([
- 'Incorrect KL({}, {}).'.format(type(p).__name__, type(q).__name__),
- 'Expected ({} Monte Carlo samples): {}'.format(denominator, expected),
- 'Actual (analytic): {}'.format(actual),
+ f'Incorrect KL({type(p).__name__}, {type(q).__name__}).',
+ f'Expected ({denominator} Monte Carlo samples): {expected}',
+ f'Actual (analytic): {actual}',
]))

# Multivariate normal has a separate Monte Carlo based test due to the requirement of random generation of
@@ -4174,9 +4174,9 @@ class TestKL(DistributionsTestCase):
if error[error == error].max() < self.precision:
break
self.assertLess(error[error == error].max(), self.precision, '\n'.join([
- 'Incorrect KL(MultivariateNormal, MultivariateNormal) instance {}/{}'.format(i + 1, n),
- 'Expected ({} Monte Carlo sample): {}'.format(denominator, expected),
- 'Actual (analytic): {}'.format(actual),
+ f'Incorrect KL(MultivariateNormal, MultivariateNormal) instance {i + 1}/{n}',
+ f'Expected ({denominator} Monte Carlo sample): {expected}',
+ f'Actual (analytic): {actual}',
]))

def test_kl_multivariate_normal_batched(self):
@@ -4223,23 +4223,23 @@ class TestKL(DistributionsTestCase):

error_lowrank_lowrank = torch.abs(actual_lowrank_lowrank - expected).max()
self.assertLess(error_lowrank_lowrank, self.precision, '\n'.join([
- 'Incorrect KL(LowRankMultivariateNormal, LowRankMultivariateNormal) instance {}/{}'.format(i + 1, n),
- 'Expected (from KL MultivariateNormal): {}'.format(expected),
- 'Actual (analytic): {}'.format(actual_lowrank_lowrank),
+ f'Incorrect KL(LowRankMultivariateNormal, LowRankMultivariateNormal) instance {i + 1}/{n}',
+ f'Expected (from KL MultivariateNormal): {expected}',
+ f'Actual (analytic): {actual_lowrank_lowrank}',
]))

error_lowrank_full = torch.abs(actual_lowrank_full - expected).max()
self.assertLess(error_lowrank_full, self.precision, '\n'.join([
- 'Incorrect KL(LowRankMultivariateNormal, MultivariateNormal) instance {}/{}'.format(i + 1, n),
- 'Expected (from KL MultivariateNormal): {}'.format(expected),
- 'Actual (analytic): {}'.format(actual_lowrank_full),
+ f'Incorrect KL(LowRankMultivariateNormal, MultivariateNormal) instance {i + 1}/{n}',
+ f'Expected (from KL MultivariateNormal): {expected}',
+ f'Actual (analytic): {actual_lowrank_full}',
]))

error_full_lowrank = torch.abs(actual_full_lowrank - expected).max()
self.assertLess(error_full_lowrank, self.precision, '\n'.join([
- 'Incorrect KL(MultivariateNormal, LowRankMultivariateNormal) instance {}/{}'.format(i + 1, n),
- 'Expected (from KL MultivariateNormal): {}'.format(expected),
- 'Actual (analytic): {}'.format(actual_full_lowrank),
+ f'Incorrect KL(MultivariateNormal, LowRankMultivariateNormal) instance {i + 1}/{n}',
+ f'Expected (from KL MultivariateNormal): {expected}',
+ f'Actual (analytic): {actual_full_lowrank}',
]))

def test_kl_lowrank_multivariate_normal_batched(self):
@@ -4261,16 +4261,16 @@ class TestKL(DistributionsTestCase):
actual = kl_divergence(p, q)
expected = _kl_expfamily_expfamily(p, q)
self.assertEqual(actual, expected, msg='\n'.join([
- 'Incorrect KL({}, {}).'.format(type(p).__name__, type(q).__name__),
- 'Expected (using Bregman Divergence) {}'.format(expected),
- 'Actual (analytic) {}'.format(actual),
- 'max error = {}'.format(torch.abs(actual - expected).max())
+ f'Incorrect KL({type(p).__name__}, {type(q).__name__}).',
+ f'Expected (using Bregman Divergence) {expected}',
+ f'Actual (analytic) {actual}',
+ f'max error = {torch.abs(actual - expected).max()}'
]))

def test_kl_infinite(self):
for p, q in self.infinite_examples:
self.assertTrue((kl_divergence(p, q) == inf).all(),
- 'Incorrect KL({}, {})'.format(type(p).__name__, type(q).__name__))
+ f'Incorrect KL({type(p).__name__}, {type(q).__name__})')

def test_kl_edgecases(self):
self.assertEqual(kl_divergence(Bernoulli(0), Bernoulli(0)), 0)
@@ -4287,9 +4287,9 @@ class TestKL(DistributionsTestCase):
continue
expected_shape = dist.batch_shape if dist.batch_shape else torch.Size()
self.assertEqual(kl.shape, expected_shape, msg='\n'.join([
- '{} example {}/{}'.format(Dist.__name__, i + 1, len(params)),
- 'Expected {}'.format(expected_shape),
- 'Actual {}'.format(kl.shape),
+ f'{Dist.__name__} example {i + 1}/{len(params)}',
+ f'Expected {expected_shape}',
+ f'Actual {kl.shape}',
]))

def test_kl_transformed(self):
@@ -4316,10 +4316,10 @@ class TestKL(DistributionsTestCase):
ignore = (expected == inf) | (expected == -inf)
expected[ignore] = actual[ignore]
self.assertEqual(actual, expected, atol=0.2, rtol=0, msg='\n'.join([
- '{} example {}/{}, incorrect .entropy().'.format(Dist.__name__, i + 1, len(params)),
- 'Expected (monte carlo) {}'.format(expected),
- 'Actual (analytic) {}'.format(actual),
- 'max error = {}'.format(torch.abs(actual - expected).max()),
+ f'{Dist.__name__} example {i + 1}/{len(params)}, incorrect .entropy().',
+ f'Expected (monte carlo) {expected}',
+ f'Actual (analytic) {actual}',
+ f'max error = {torch.abs(actual - expected).max()}',
]))

def test_entropy_exponential_family(self):
@@ -4337,10 +4337,10 @@ class TestKL(DistributionsTestCase):
except NotImplementedError:
continue
self.assertEqual(actual, expected, msg='\n'.join([
- '{} example {}/{}, incorrect .entropy().'.format(Dist.__name__, i + 1, len(params)),
- 'Expected (Bregman Divergence) {}'.format(expected),
- 'Actual (analytic) {}'.format(actual),
- 'max error = {}'.format(torch.abs(actual - expected).max())
+ f'{Dist.__name__} example {i + 1}/{len(params)}, incorrect .entropy().',
+ f'Expected (Bregman Divergence) {expected}',
+ f'Actual (analytic) {actual}',
+ f'max error = {torch.abs(actual - expected).max()}'
]))
@@ -4632,7 +4632,7 @@ class TestLazyLogitsInitialization(DistributionsTestCase):
dist = Dist(**param)
# Create new instance to generate a valid sample
dist.log_prob(Dist(**param).sample())
- message = 'Failed for {} example 0/{}'.format(Dist.__name__, len(params))
+ message = f'Failed for {Dist.__name__} example 0/{len(params)}'
self.assertNotIn('probs', dist.__dict__, msg=message)
try:
dist.enumerate_support()
@@ -4649,7 +4649,7 @@ class TestLazyLogitsInitialization(DistributionsTestCase):
continue
dist = Dist(**param)
dist.sample()
- message = 'Failed for {} example 0/{}'.format(Dist.__name__, len(params))
+ message = f'Failed for {Dist.__name__} example 0/{len(params)}'
self.assertNotIn('logits', dist.__dict__, msg=message)
try:
dist.enumerate_support()
@@ -5161,7 +5161,7 @@ class TestJit(DistributionsTestCase):
expected = f(sample, *values)
actual = traced_f(sample, *values)
self.assertEqual(expected, actual,
- msg='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))
+ msg=f'{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}')

def test_enumerate_support(self):
for Dist, keys, values, sample in self._examples():
@@ -5185,7 +5185,7 @@ class TestJit(DistributionsTestCase):
expected = f(*values)
actual = traced_f(*values)
self.assertEqual(expected, actual,
- msg='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))
+ msg=f'{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}')

def test_mean(self):
for Dist, keys, values, sample in self._examples():
@@ -5207,7 +5207,7 @@ class TestJit(DistributionsTestCase):
expected[expected == float('inf')] = 0.
actual[actual == float('inf')] = 0.
self.assertEqual(expected, actual,
- msg='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))
+ msg=f'{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}')

def test_variance(self):
for Dist, keys, values, sample in self._examples():
@@ -5231,7 +5231,7 @@ class TestJit(DistributionsTestCase):
expected[expected == float('inf')] = 0.
actual[actual == float('inf')] = 0.
self.assertEqual(expected, actual,
- msg='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))
+ msg=f'{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}')

def test_entropy(self):
for Dist, keys, values, sample in self._examples():
@@ -5255,7 +5255,7 @@ class TestJit(DistributionsTestCase):
expected = f(*values)
actual = traced_f(*values)
self.assertEqual(expected, actual,
- msg='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))
+ msg=f'{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}')

def test_cdf(self):
for Dist, keys, values, sample in self._examples():
@@ -5276,7 +5276,7 @@ class TestJit(DistributionsTestCase):
expected = f(sample, *values)
actual = traced_f(sample, *values)
self.assertEqual(expected, actual,
- msg='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))
+ msg=f'{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}')


if __name__ == '__main__' and torch._C.has_lapack:
@@ -156,7 +156,7 @@ def generate_data(transform):
x /= x.norm(dim=-1, keepdim=True)
x.diagonal(dim1=-1).copy_(x.diagonal(dim1=-1).abs())
return x
- raise ValueError('Unsupported domain: {}'.format(domain))
+ raise ValueError(f'Unsupported domain: {domain}')


TRANSFORMS_CACHE_ACTIVE = get_transforms(cache_size=1)
@@ -215,19 +215,19 @@ def test_forward_inverse(transform, test_cached):
if transform.bijective:
# verify function inverse
assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), '\n'.join([
- '{} t.inv(t(-)) error'.format(transform),
- 'x = {}'.format(x),
- 'y = t(x) = {}'.format(y),
- 'x2 = t.inv(y) = {}'.format(x2),
+ f'{transform} t.inv(t(-)) error',
+ f'x = {x}',
+ f'y = t(x) = {y}',
+ f'x2 = t.inv(y) = {x2}',
])
else:
# verify weaker function pseudo-inverse
assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), '\n'.join([
- '{} t(t.inv(t(-))) error'.format(transform),
- 'x = {}'.format(x),
- 'y = t(x) = {}'.format(y),
- 'x2 = t.inv(y) = {}'.format(x2),
- 'y2 = t(x2) = {}'.format(y2),
+ f'{transform} t(t.inv(t(-))) error',
+ f'x = {x}',
+ f'y = t(x) = {y}',
+ f'x2 = t.inv(y) = {x2}',
+ f'y2 = t(x2) = {y2}',
])
@@ -1701,8 +1701,8 @@ class TestOptim(TestCase):

num_tensors = 5
for functional_optim, amsgrad, no_grad_scale in itertools.product((adam.adam, adamw.adamw), (False, True), (False, True)):
- params, grads, exp_avgs, exp_avg_sqs = [
- [torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(4)]
+ params, grads, exp_avgs, exp_avg_sqs = (
+ [torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(4))
prev_params = [t.clone().detach() for t in params]
max_exp_avg_sqs = [torch.ones((1,), device="cuda") for _ in range(num_tensors)] if amsgrad else []
state_steps = [torch.ones((), dtype=torch.float32, device="cuda") for _ in range(num_tensors)]
@@ -258,7 +258,7 @@ class _IntegerInterval(Constraint):

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
- fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
+ fmt_string += f'(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})'
return fmt_string


@@ -277,7 +277,7 @@ class _IntegerLessThan(Constraint):

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
- fmt_string += '(upper_bound={})'.format(self.upper_bound)
+ fmt_string += f'(upper_bound={self.upper_bound})'
return fmt_string


@@ -296,7 +296,7 @@ class _IntegerGreaterThan(Constraint):

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
- fmt_string += '(lower_bound={})'.format(self.lower_bound)
+ fmt_string += f'(lower_bound={self.lower_bound})'
return fmt_string


@@ -321,7 +321,7 @@ class _GreaterThan(Constraint):

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
- fmt_string += '(lower_bound={})'.format(self.lower_bound)
+ fmt_string += f'(lower_bound={self.lower_bound})'
return fmt_string


@@ -338,7 +338,7 @@ class _GreaterThanEq(Constraint):

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
- fmt_string += '(lower_bound={})'.format(self.lower_bound)
+ fmt_string += f'(lower_bound={self.lower_bound})'
return fmt_string


@@ -355,7 +355,7 @@ class _LessThan(Constraint):

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
- fmt_string += '(upper_bound={})'.format(self.upper_bound)
+ fmt_string += f'(upper_bound={self.upper_bound})'
return fmt_string


@@ -373,7 +373,7 @@ class _Interval(Constraint):

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
- fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
+ fmt_string += f'(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})'
return fmt_string


@@ -391,7 +391,7 @@ class _HalfOpenInterval(Constraint):

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
- fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
+ fmt_string += f'(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})'
return fmt_string
@@ -109,4 +109,4 @@ class Independent(Distribution):
return self.base_dist.enumerate_support(expand=expand)

def __repr__(self):
- return self.__class__.__name__ + '({}, {})'.format(self.base_dist, self.reinterpreted_batch_ndims)
+ return self.__class__.__name__ + f'({self.base_dist}, {self.reinterpreted_batch_ndims})'
@@ -65,9 +65,9 @@ def register_kl(type_p, type_q):
type_q (type): A subclass of :class:`~torch.distributions.Distribution`.
"""
if not isinstance(type_p, type) and issubclass(type_p, Distribution):
- raise TypeError('Expected type_p to be a Distribution subclass but got {}'.format(type_p))
+ raise TypeError(f'Expected type_p to be a Distribution subclass but got {type_p}')
if not isinstance(type_q, type) and issubclass(type_q, Distribution):
- raise TypeError('Expected type_q to be a Distribution subclass but got {}'.format(type_q))
+ raise TypeError(f'Expected type_q to be a Distribution subclass but got {type_q}')

def decorator(fun):
_KL_REGISTRY[type_p, type_q] = fun
@@ -735,7 +735,7 @@ def _kl_uniform_beta(p, q):
common_term = p.high - p.low
t1 = torch.log(common_term)
t2 = (q.concentration1 - 1) * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) / common_term
- t3 = (q.concentration0 - 1) * (_x_log_x((1 - p.high)) - _x_log_x((1 - p.low)) + common_term) / common_term
+ t3 = (q.concentration0 - 1) * (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term) / common_term
t4 = q.concentration1.lgamma() + q.concentration0.lgamma() - (q.concentration1 + q.concentration0).lgamma()
result = t3 + t4 - t1 - t2
result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf
@@ -93,7 +93,7 @@ class LowRankMultivariateNormal(Distribution):
raise ValueError("cov_factor must be a batch of matrices with shape {} x m"
.format(event_shape[0]))
if cov_diag.shape[-1:] != event_shape:
- raise ValueError("cov_diag must be a batch of vectors with shape {}".format(event_shape))
+ raise ValueError(f"cov_diag must be a batch of vectors with shape {event_shape}")

loc_ = loc.unsqueeze(-1)
cov_diag_ = cov_diag.unsqueeze(-1)
@@ -71,17 +71,17 @@ class MixtureSameFamily(Distribution):
cdbs = self._component_distribution.batch_shape[:-1]
for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
if size1 != 1 and size2 != 1 and size1 != size2:
- raise ValueError("`mixture_distribution.batch_shape` ({0}) is not "
+ raise ValueError(f"`mixture_distribution.batch_shape` ({mdbs}) is not "
"compatible with `component_distribution."
- "batch_shape`({1})".format(mdbs, cdbs))
+ f"batch_shape`({cdbs})")

# Check that the number of mixture component matches
km = self._mixture_distribution.logits.shape[-1]
kc = self._component_distribution.batch_shape[-1]
if km is not None and kc is not None and km != kc:
- raise ValueError("`mixture_distribution component` ({0}) does not"
+ raise ValueError(f"`mixture_distribution component` ({km}) does not"
" equal `component_distribution.batch_shape[-1]`"
- " ({1})".format(km, kc))
+ f" ({kc})")
self._num_component = km

event_shape = self._component_distribution.event_shape
@@ -51,7 +51,7 @@ class TransformedDistribution(Distribution):
raise ValueError("transforms must be a Transform or a list of Transforms")
self.transforms = transforms
else:
- raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))
+ raise ValueError(f"transforms must be a Transform or list, but was {transforms}")

# Reshape base_distribution according to transforms.
base_shape = base_distribution.batch_shape + base_distribution.event_shape
@@ -135,7 +135,7 @@ class Transform:
return self
if type(self).__init__ is Transform.__init__:
return type(self)(cache_size=cache_size)
- raise NotImplementedError("{}.with_cache is not implemented".format(type(self)))
+ raise NotImplementedError(f"{type(self)}.with_cache is not implemented")

def __eq__(self, other):
return self is other
@@ -506,7 +506,7 @@ class ReshapeTransform(Transform):
raise ValueError("Too few dimensions on input")
cut = len(shape) - len(self.in_shape)
if shape[cut:] != self.in_shape:
- raise ValueError("Shape mismatch: expected {} but got {}".format(shape[cut:], self.in_shape))
+ raise ValueError(f"Shape mismatch: expected {shape[cut:]} but got {self.in_shape}")
return shape[:cut] + self.out_shape

def inverse_shape(self, shape):
@@ -514,7 +514,7 @@ class ReshapeTransform(Transform):
raise ValueError("Too few dimensions on input")
cut = len(shape) - len(self.out_shape)
if shape[cut:] != self.out_shape:
- raise ValueError("Shape mismatch: expected {} but got {}".format(shape[cut:], self.out_shape))
+ raise ValueError(f"Shape mismatch: expected {shape[cut:]} but got {self.out_shape}")
return shape[:cut] + self.in_shape
@@ -22,13 +22,13 @@ class Adadelta(Optimizer):
differentiable: bool = False,
):
if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
+ raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= rho <= 1.0:
- raise ValueError("Invalid rho value: {}".format(rho))
+ raise ValueError(f"Invalid rho value: {rho}")
if not 0.0 <= eps:
- raise ValueError("Invalid epsilon value: {}".format(eps))
+ raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")

defaults = dict(
lr=lr,
@@ -23,11 +23,11 @@ class Adagrad(Optimizer):
differentiable: bool = False,
):
if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
+ raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= lr_decay:
- raise ValueError("Invalid lr_decay value: {}".format(lr_decay))
+ raise ValueError(f"Invalid lr_decay value: {lr_decay}")
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= initial_accumulator_value:
raise ValueError(
"Invalid initial_accumulator_value value: {}".format(
@@ -35,7 +35,7 @@ class Adagrad(Optimizer):
)
)
if not 0.0 <= eps:
- raise ValueError("Invalid epsilon value: {}".format(eps))
+ raise ValueError(f"Invalid epsilon value: {eps}")

defaults = dict(
lr=lr,
@@ -16,15 +16,15 @@ class Adam(Optimizer):
maximize: bool = False, capturable: bool = False,
differentiable: bool = False, fused: Optional[bool] = None):
if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
+ raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
- raise ValueError("Invalid epsilon value: {}".format(eps))
+ raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")

defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad,
@@ -22,15 +22,15 @@ class Adamax(Optimizer):
differentiable: bool = False,
):
if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
+ raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
- raise ValueError("Invalid epsilon value: {}".format(eps))
+ raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")

defaults = dict(
lr=lr,
@@ -26,15 +26,15 @@ class AdamW(Optimizer):
fused: Optional[bool] = None,
):
if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
+ raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
- raise ValueError("Invalid epsilon value: {}".format(eps))
+ raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
betas=betas,
@@ -28,9 +28,9 @@ class ASGD(Optimizer):
differentiable: bool = False,
):
if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
+ raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")

defaults = dict(
lr=lr,
@@ -1366,11 +1366,11 @@ class CosineAnnealingWarmRestarts(LRScheduler):

def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
if T_0 <= 0 or not isinstance(T_0, int):
- raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
+ raise ValueError(f"Expected positive integer T_0, but got {T_0}")
if T_mult < 1 or not isinstance(T_mult, int):
- raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
+ raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")
if not isinstance(eta_min, (float, int)):
- raise ValueError("Expected float or int eta_min, but got {} of type {}".format(eta_min, type(eta_min)))
+ raise ValueError(f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}")
self.T_0 = T_0
self.T_i = T_0
self.T_mult = T_mult
@@ -1425,7 +1425,7 @@ class CosineAnnealingWarmRestarts(LRScheduler):
self.T_i = self.T_i * self.T_mult
else:
if epoch < 0:
- raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
+ raise ValueError(f"Expected non-negative epoch, but got {epoch}")
if epoch >= self.T_0:
if self.T_mult == 1:
self.T_cur = epoch % self.T_0
@@ -1590,13 +1590,13 @@ class OneCycleLR(LRScheduler):
raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
elif total_steps is not None:
if total_steps <= 0 or not isinstance(total_steps, int):
- raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps))
+ raise ValueError(f"Expected positive integer total_steps, but got {total_steps}")
self.total_steps = total_steps
else:
if epochs <= 0 or not isinstance(epochs, int):
- raise ValueError("Expected positive integer epochs, but got {}".format(epochs))
+ raise ValueError(f"Expected positive integer epochs, but got {epochs}")
if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
- raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch))
+ raise ValueError(f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}")
self.total_steps = epochs * steps_per_epoch

if three_phase:
@@ -1643,11 +1643,11 @@ class OneCycleLR(LRScheduler):

# Validate pct_start
if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
- raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))
+ raise ValueError(f"Expected float between 0 and 1 pct_start, but got {pct_start}")

# Validate anneal_strategy
if anneal_strategy not in ['cos', 'linear']:
- raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
+ raise ValueError(f"anneal_strategy must by one of 'cos' or 'linear', instead got {anneal_strategy}")
elif anneal_strategy == 'cos':
self.anneal_func = self._annealing_cos
elif anneal_strategy == 'linear':
@@ -11,17 +11,17 @@ class NAdam(Optimizer):
weight_decay=0, momentum_decay=4e-3, *, foreach: Optional[bool] = None,
differentiable: bool = False):
if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
+ raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
- raise ValueError("Invalid epsilon value: {}".format(eps))
+ raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= momentum_decay:
- raise ValueError("Invalid momentum_decay value: {}".format(momentum_decay))
+ raise ValueError(f"Invalid momentum_decay value: {momentum_decay}")
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, momentum_decay=momentum_decay,
foreach=foreach, differentiable=differentiable)
@@ -246,10 +246,10 @@ class Optimizer:
format_string = self.__class__.__name__ + ' ('
for i, group in enumerate(self.param_groups):
format_string += '\n'
- format_string += 'Parameter Group {0}\n'.format(i)
+ format_string += f'Parameter Group {i}\n'
for key in sorted(group.keys()):
if key != 'params':
- format_string += ' {0}: {1}\n'.format(key, group[key])
+ format_string += f' {key}: {group[key]}\n'
format_string += ')'
return format_string

@@ -304,7 +304,7 @@ class Optimizer:
@functools.wraps(func)
def wrapper(*args, **kwargs):
self, *_ = args
- profile_name = "Optimizer.step#{}.step".format(self.__class__.__name__)
+ profile_name = f"Optimizer.step#{self.__class__.__name__}.step"
with torch.autograd.profiler.record_function(profile_name):
# call optimizer step pre hooks
for pre_hook in chain(_global_optimizer_pre_hooks.values(), self._optimizer_step_pre_hooks.values()):
@@ -337,7 +337,7 @@ class Optimizer:
return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)

def _patch_step_function(self):
- self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__)
+ self._zero_grad_profile_name = f"Optimizer.zero_grad#{self.__class__.__name__}.zero_grad"
hooked = getattr(self.__class__.step, "hooked", None)
if not hooked:
self.__class__.step = self.profile_hook_step(self.__class__.step) # type: ignore[method-assign]
@@ -468,8 +468,8 @@ class Optimizer:
"that doesn't match the size of optimizer's group")

# Update the state
- id_map = dict(zip(chain.from_iterable((g['params'] for g in saved_groups)),
- chain.from_iterable((g['params'] for g in groups))))
+ id_map = dict(zip(chain.from_iterable(g['params'] for g in saved_groups),
+ chain.from_iterable(g['params'] for g in groups)))

def cast(param, value, param_id=None, param_groups=None, key=None):
r"""Make a deep copy of value, casting all tensors to device of param."""
@@ -22,15 +22,15 @@ class RAdam(Optimizer):
differentiable: bool = False,
):
if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
+ raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
- raise ValueError("Invalid epsilon value: {}".format(eps))
+ raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
betas=betas,
@@ -22,15 +22,15 @@ class RMSprop(Optimizer):
differentiable: bool = False,
):
if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
+ raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
- raise ValueError("Invalid epsilon value: {}".format(eps))
+ raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= momentum:
- raise ValueError("Invalid momentum value: {}".format(momentum))
+ raise ValueError(f"Invalid momentum value: {momentum}")
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= alpha:
- raise ValueError("Invalid alpha value: {}".format(alpha))
+ raise ValueError(f"Invalid alpha value: {alpha}")

defaults = dict(
lr=lr,
@ -20,9 +20,9 @@ class Rprop(Optimizer):
differentiable: bool = False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 < etas[0] < 1.0 < etas[1]:
raise ValueError("Invalid eta values: {}, {}".format(etas[0], etas[1]))
raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}")

defaults = dict(
lr=lr,
@ -11,11 +11,11 @@ class SGD(Optimizer):
weight_decay=0, nesterov=False, *, maximize: bool = False, foreach: Optional[bool] = None,
differentiable: bool = False):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
raise ValueError(f"Invalid weight_decay value: {weight_decay}")

defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov,
@ -7,13 +7,13 @@ __all__ = ['SparseAdam']
class SparseAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, maximize: bool = False):
if not 0.0 < lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 < eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

params = list(params)
@ -31,13 +31,13 @@ def _resolve_name(name, package, level):
if len(bits) < level:
raise ValueError("attempted relative import beyond top-level package")
base = bits[0]
return "{}.{}".format(base, name) if name else base
return f"{base}.{name}" if name else base


def _sanity_check(name, package, level):
"""Verify arguments are "sane"."""
if not isinstance(name, str):
raise TypeError("module name must be str, not {}".format(type(name)))
raise TypeError(f"module name must be str, not {type(name)}")
if level < 0:
raise ValueError("level must be >= 0")
if level > 0:
@ -90,6 +90,6 @@ def _normalize_path(path):
"""
parent, file_name = os.path.split(path)
if parent:
raise ValueError("{!r} must be only a file name".format(path))
raise ValueError(f"{path!r} must be only a file name")
else:
return file_name
@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from typing import Dict, List

from .glob_group import GlobGroup, GlobPattern
@ -79,7 +79,7 @@ class PackagingErrorReason(Enum):
"""

def __repr__(self):
return "<%s.%s>" % (self.__class__.__name__, self.name)
return f"<{self.__class__.__name__}.{self.name}>"

IS_EXTENSION_MODULE = (
"Module is a C extension module. torch.package supports Python modules only."
@ -156,14 +156,12 @@ class PackagingError(Exception):
message.write(f" Context: {error_context}\n")
if module_name in _DISALLOWED_MODULES:
message.write(
(
" Note: While we usually use modules in the python standard library "
f"from the local environment, `{module_name}` has a lot of system "
"level access and therefore can pose a security risk. We heavily "
f"recommend removing `{module_name}` from your packaged code. However, if that "
"is not possible, add it to the extern list by calling "
f'PackageExporter.extern("`{module_name}`")\n'
)
" Note: While we usually use modules in the python standard library "
f"from the local environment, `{module_name}` has a lot of system "
"level access and therefore can pose a security risk. We heavily "
f"recommend removing `{module_name}` from your packaged code. However, if that "
"is not possible, add it to the extern list by calling "
f'PackageExporter.extern("`{module_name}`")\n'
)
if debug:
module_path = dependency_graph.first_path(module_name)
@ -173,10 +171,8 @@ class PackagingError(Exception):
if not debug:
message.write("\n")
message.write(
(
"Set debug=True when invoking PackageExporter for a visualization of where "
"broken modules are coming from!\n"
)
"Set debug=True when invoking PackageExporter for a visualization of where "
"broken modules are coming from!\n"
)
# Save the dependency graph so that tooling can get at it.
self.dependency_graph = dependency_graph
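For reference, a tiny self-contained check (made-up strings, hypothetical module name) that dropping the extra parentheses around implicitly concatenated string literals, as in the two hunks above, does not change the resulting message:

module_name = "io"  # hypothetical, only for the comparison
wrapped = (
    (
        "Note: While we usually use modules in the python standard library "
        f"from the local environment, `{module_name}` has a lot of system access."
    )
)
flat = (
    "Note: While we usually use modules in the python standard library "
    f"from the local environment, `{module_name}` has a lot of system access."
)
assert wrapped == flat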
@ -539,7 +539,7 @@ class PackageImporter(Importer):
if not recursive and hasattr(module, "__all__"):
self._handle_fromlist(module, module.__all__, recursive=True)
elif not hasattr(module, x):
from_name = "{}.{}".format(module_name, x)
from_name = f"{module_name}.{x}"
try:
self._gcd_import(from_name)
except ModuleNotFoundError as exc:
@ -587,13 +587,13 @@ class PackageImporter(Importer):
"""
if hasattr(package, "__spec__"):
if package.__spec__.submodule_search_locations is None:
raise TypeError("{!r} is not a package".format(package.__spec__.name))
raise TypeError(f"{package.__spec__.name!r} is not a package")
else:
return package
else:
module = self.import_module(package)
if module.__spec__.submodule_search_locations is None:
raise TypeError("{!r} is not a package".format(package))
raise TypeError(f"{package!r} is not a package")
else:
return module
@ -738,11 +738,11 @@ class MemoryProfile:

for node in self._data_flow_graph.flow_nodes:
all_tensor_versions.update(((k, v) for k, (_, v) in node.inputs.items()))
all_tensor_versions.update(((key, 0) for key in node.intermediates))
all_tensor_versions.update((key, 0) for key in node.intermediates)
all_tensor_versions.update(node.outputs.items())

for i in self._categories._values.values():
all_tensor_versions.update(((key, 0) for key in i._by_id_keyset))
all_tensor_versions.update((key, 0) for key in i._by_id_keyset)

return {
(key, version): self._categories.get(key, version)
@ -642,7 +642,7 @@ def report_all_anti_patterns(prof,
json_report_path = os.path.join(json_report_dir,
"torchtidy_report.json")
if os.path.exists(json_report_path):
with open(json_report_path, "r") as f:
with open(json_report_path) as f:
exisiting_report = json.load(f)
exisiting_report.update(report_dict)
report_dict = exisiting_report
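A short illustrative snippet (temporary file, hypothetical contents) showing why dropping the explicit "r" above is safe: reading is already open()'s default mode:

import json
import os
import tempfile

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "torchtidy_report.json")
    with open(path, "w") as f:
        json.dump({"ok": True}, f)
    with open(path) as f:  # equivalent to open(path, "r")
        assert json.load(f) == {"ok": True}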
@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from typing import Optional, Iterable

import torch
@ -136,28 +136,22 @@ class SparseSemiStructuredTensor(torch.Tensor):
# check device
if not original_tensor.is_cuda:
raise RuntimeError(
(
f"Error original_tensor.device= {original_tensor.device} is not supported! "
"Only CUDA tensors are currently supported."
)
f"Error original_tensor.device= {original_tensor.device} is not supported! "
"Only CUDA tensors are currently supported."
)

# check dim
if original_tensor.dim() != 2:
raise RuntimeError(
(
f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
"Only 2d tensors are currently supported."
)
f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
"Only 2d tensors are currently supported."
)

# check dtype
if original_tensor.dtype not in _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG:
raise RuntimeError(
(
f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
"dtype must be one of: {_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG}"
)
f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
"dtype must be one of: {_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG}"
)

# check shape
@ -167,10 +161,8 @@ class SparseSemiStructuredTensor(torch.Tensor):
if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
# TODO in the future we can add in padding to support dimensions that aren't perfect multiples
raise RuntimeError(
(
f"Error original_tensor.shape {original_tensor.shape} is not supported! "
"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
)
f"Error original_tensor.shape {original_tensor.shape} is not supported! "
"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
)

# This code calculates the size of the compressed tensor.