mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-22 22:25:10 +08:00
Compare commits: 15 commits (cpp-docs-d ... ciflow/tru)
SHA1: b0bdd77ace, b043f936b6, 440c889bbe, 443a928e65, 0e713330c8, f8ffa5edb7, 6bf54e4cd7, a59dd53564, 0c0b21c533, a299dca9a5, 65b1fd617d, 51f414ddfc, 05822ad915, cf1dc28b2c, c1be445be1
aten/src/ATen/native/LinearCrossEntropy.cpp (new file, 652 lines)
@@ -0,0 +1,652 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/Resize.h>
#include <ATen/TensorIterator.h>
#include <c10/util/irange.h>
#include <limits>
#include <tuple>
#include <vector>
#include <optional>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/linear.h>
#include <ATen/ops/cross_entropy_loss.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/full.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/max.h>
#include <ATen/ops/exp.h>
#include <ATen/ops/log.h>
#include <ATen/ops/logsumexp.h>
#include <ATen/ops/where.h>
#include <ATen/ops/ge.h>
#include <ATen/ops/lt.h>
#include <ATen/ops/logical_and.h>
#include <ATen/ops/logical_or.h>
#include <ATen/ops/logical_not.h>
#include <ATen/ops/masked_fill.h>
#include <ATen/ops/sub.h>
#include <ATen/ops/add.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/index_select.h>
#include <ATen/ops/gather.h>
#include <ATen/ops/nonzero.h>
#include <ATen/ops/maximum.h>
#include <ATen/ops/gt.h>
#include <ATen/ops/div.h>
#include <ATen/ops/ne.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/linear_cross_entropy_backward_native.h>
#include <ATen/ops/linear_cross_entropy_native.h>
#endif

namespace at::native {

// Strategy selection for optimal chunking approach
enum class ChunkingStrategy {
  NAIVE,           // No chunking - standard approach
  VOCAB_CHUNKING,  // Chunk vocabulary dimension (existing)
  BATCH_CHUNKING   // Chunk batch dimension (new)
};

Tensor batch_chunking_cpu(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& target,
    const std::optional<Tensor>& bias_opt,
    int64_t reduction,
    int64_t ignore_index,
    double label_smoothing);

// Determine optimal chunking strategy based on input dimensions and user preference
// Based on memory reduction analysis and empirical validation
inline ChunkingStrategy select_chunking_strategy(
    int64_t vocab_size,
    int64_t flattened_batch_size,
    c10::string_view strategy) {
  if (strategy == "none") {
    return ChunkingStrategy::NAIVE;
  } else if (strategy == "vocab") {
    return ChunkingStrategy::VOCAB_CHUNKING;
  } else if (strategy == "batch") {
    return ChunkingStrategy::BATCH_CHUNKING;
  } else if (strategy == "auto") {
    // Empirically validated chunk sizes for optimal memory/compute balance
    const int64_t vocab_chunk_size = 4096;  // Same as existing implementation
    const int64_t batch_chunk_size = 1024;  // Optimized for batch processing

    const int64_t total_batch_size = flattened_batch_size;

    // Determine which dimensions benefit from chunking
    bool vocab_large = vocab_size > vocab_chunk_size;
    bool batch_large = total_batch_size > batch_chunk_size;

    if (!vocab_large && !batch_large) {
      return ChunkingStrategy::NAIVE;
    } else if (vocab_large && !batch_large) {
      return ChunkingStrategy::VOCAB_CHUNKING;
    } else if (!vocab_large && batch_large) {
      return ChunkingStrategy::BATCH_CHUNKING;
    } else {
      // Both dimensions are large - choose strategy with better memory reduction
      // Memory reduction = 1 - (chunk_size / total_size)
      double vocab_reduction = 1.0 - static_cast<double>(vocab_chunk_size) / vocab_size;
      double batch_reduction = 1.0 - static_cast<double>(batch_chunk_size) / total_batch_size;

      return (vocab_reduction >= batch_reduction) ?
          ChunkingStrategy::VOCAB_CHUNKING : ChunkingStrategy::BATCH_CHUNKING;
    }
  } else {
    TORCH_CHECK(false, "Unknown chunking strategy: ", strategy,
                ". Valid options: 'auto', 'vocab', 'batch', 'none'");
  }
}
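Aside (standalone illustration, not part of LinearCrossEntropy.cpp): a minimal Python sketch of the same "auto" decision rule, handy for predicting which path a given shape takes; the helper name and the worked numbers are illustrative only.

# Mirrors select_chunking_strategy above: chunk only when a dimension exceeds
# its chunk size, and when both do, pick whichever chunking removes the larger
# fraction of the [N, V] logits buffer (memory reduction = 1 - chunk/total).
def pick_strategy(vocab_size, flattened_batch, vocab_chunk=4096, batch_chunk=1024):
    vocab_large = vocab_size > vocab_chunk
    batch_large = flattened_batch > batch_chunk
    if not vocab_large and not batch_large:
        return "none"
    if vocab_large and not batch_large:
        return "vocab"
    if not vocab_large and batch_large:
        return "batch"
    vocab_reduction = 1.0 - vocab_chunk / vocab_size
    batch_reduction = 1.0 - batch_chunk / flattened_batch
    return "vocab" if vocab_reduction >= batch_reduction else "batch"

# e.g. vocab_size=50_000, flattened_batch=8_192:
#   vocab_reduction ~= 0.918, batch_reduction ~= 0.875 -> "vocab"
print(pick_strategy(50_000, 8_192))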
// Apply final reduction based on reduction mode
// Handles mean/sum reduction consistently across all chunking strategies
// Batch chunking implementation for CPU
// Inspired by Liger Kernel approach: processes input in batch chunks to reduce memory usage
// Memory reduction: [N, V] -> [chunk_size, V] where N = batch_size * seq_len
Tensor batch_chunking_cpu(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& target,
    const std::optional<Tensor>& bias_opt,
    int64_t reduction,
    int64_t ignore_index,
    double label_smoothing) {
  // Flatten multi-dimensional inputs for processing (standard PyTorch pattern)
  // This allows handling both 2D [batch, hidden] and 3D [batch, seq, hidden] inputs
  auto input_flat = input.reshape({-1, input.size(-1)});  // [N, H] where N = batch * seq_len
  auto target_flat = target.reshape({-1});                // [N] flattened targets

  const int64_t batch_size = input_flat.size(0);
  const int64_t chunk_size = 1024;  // Empirically optimized for batch dimension chunking

  // Get bias tensor if provided
  const Tensor& bias = bias_opt.value_or(Tensor());

  // Early exit if batch is too small for chunking
  if (batch_size <= chunk_size) {
    auto logits = at::linear(input_flat, weight, bias);
    return at::cross_entropy_loss(logits, target_flat, std::nullopt, reduction, ignore_index, label_smoothing);
  }

  const int64_t num_chunks = (batch_size + chunk_size - 1) / chunk_size;

  Tensor losses_buffer;
  if (reduction == Reduction::None) {
    losses_buffer = at::zeros({batch_size}, input.options());
  }

  Tensor total_loss = at::zeros({}, input.options());
  int64_t valid_count = 0;

  // Process input in batch chunks to avoid materializing large logit tensors
  // Each chunk computes: [chunk_size, hidden] @ [hidden, vocab] -> [chunk_size, vocab]
  for (int64_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
    int64_t start_idx = chunk_idx * chunk_size;
    int64_t end_idx = std::min(start_idx + chunk_size, batch_size);

    if (start_idx >= end_idx) {
      continue;
    }

    auto input_chunk = input_flat.slice(0, start_idx, end_idx);    // [actual_chunk_size, H]
    auto target_chunk = target_flat.slice(0, start_idx, end_idx);  // [actual_chunk_size]
    auto logits_chunk = at::linear(input_chunk, weight, bias);     // [actual_chunk_size, vocab_size]

    auto valid_mask_chunk = at::ne(target_chunk, ignore_index);
    valid_count += valid_mask_chunk.sum().item<int64_t>();

    const auto ce_reduction = (reduction == Reduction::None) ? Reduction::None : Reduction::Sum;
    auto chunk_loss = at::cross_entropy_loss(
        logits_chunk,
        target_chunk,
        std::nullopt,
        ce_reduction,
        ignore_index,
        label_smoothing);

    if (reduction == Reduction::None) {
      auto dest = losses_buffer.slice(0, start_idx, end_idx);
      dest.copy_(chunk_loss);
      dest.masked_fill_(at::logical_not(valid_mask_chunk), 0);
    } else {
      total_loss = at::add(total_loss, chunk_loss);
    }
  }

  if (reduction == Reduction::None) {
    return losses_buffer.reshape(target.sizes());
  }

  if (reduction == Reduction::Sum || valid_count == 0) {
    return total_loss;
  }
  return at::div(total_loss, valid_count);
}
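Aside (standalone illustration, not part of LinearCrossEntropy.cpp): a small Python sketch of the reduction bookkeeping used by batch_chunking_cpu above, with label smoothing omitted and a hypothetical helper name. Per-chunk losses are accumulated with sum reduction and divided once by the number of non-ignored targets, which matches an unchunked mean.

import torch
import torch.nn.functional as F

def chunked_mean_ce(logits, target, chunk=1024, ignore_index=-100):
    # Accumulate per-chunk sums and the count of valid (non-ignored) rows.
    total, valid = logits.new_zeros(()), 0
    for s in range(0, logits.size(0), chunk):
        block, tgt = logits[s:s + chunk], target[s:s + chunk]
        valid += int((tgt != ignore_index).sum())
        total = total + F.cross_entropy(
            block, tgt, reduction="sum", ignore_index=ignore_index
        )
    return total if valid == 0 else total / valid

logits = torch.randn(3000, 64)
target = torch.randint(0, 64, (3000,))
assert torch.allclose(chunked_mean_ce(logits, target),
                      F.cross_entropy(logits, target), atol=1e-5)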
Tensor linear_cross_entropy_cpu(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& target,
    const std::optional<Tensor>& bias_opt,
    int64_t reduction,
    int64_t ignore_index,
    double label_smoothing,
    c10::string_view chunking_strategy) {
  // Validate inputs
  TORCH_CHECK(input.dim() >= 2, "Expected input to have at least 2 dimensions, got ", input.dim());
  TORCH_CHECK(weight.dim() == 2, "Expected weight to be 2-dimensional, got ", weight.dim());
  TORCH_CHECK(input.size(-1) == weight.size(1),
              "Expected input.size(-1) to match weight.size(1), got ",
              input.size(-1), " and ", weight.size(1));

  // Get bias tensor if provided
  const Tensor& bias = bias_opt.value_or(Tensor());

  // Pick a chunking strategy that mirrors the Python wrapper so we only
  // materialise large logit tensors when it is worthwhile. Vocabulary chunking
  // slices the weight matrix (large vocabularies), batch chunking slices the
  // flattened batch (very large batches), and the naive path keeps the original
  // computation for small problems.

  // Calculate input dimensions for strategy selection
  const int64_t vocab_size = weight.size(0);
  const int64_t flattened_batch = input.numel() / input.size(-1);

  // Select optimal chunking strategy based on input characteristics and user preference
  ChunkingStrategy selected_strategy = select_chunking_strategy(vocab_size, flattened_batch, chunking_strategy);

  auto input_flat = input.reshape({-1, input.size(-1)});  // [N, H]
  auto target_flat = target.reshape({-1});                // [N]
  auto valid_mask = at::ne(target_flat, ignore_index);

  // Execute selected chunking strategy
  if (selected_strategy == ChunkingStrategy::VOCAB_CHUNKING) {
    const int64_t chunk_size = 4096;  // Empirically validated chunk size for optimal memory/compute balance
    const int64_t num_chunks = (vocab_size + chunk_size - 1) / chunk_size;

    const auto options = input_flat.options();
    auto long_options = options.dtype(at::kLong);
    const double neg_inf = -std::numeric_limits<double>::infinity();

    Tensor running_max = at::full({input_flat.size(0)}, neg_inf, options);
    Tensor exp_sums = at::zeros({input_flat.size(0)}, options);
    Tensor target_logits = at::zeros({input_flat.size(0)}, options);
    Tensor target_found = at::zeros({input_flat.size(0)}, long_options);
    Tensor sum_logits;
    if (label_smoothing > 0.0) {
      sum_logits = at::zeros({input_flat.size(0)}, options);
    }

    for (int64_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
      const int64_t start_idx = chunk_idx * chunk_size;
      const int64_t end_idx = std::min(start_idx + chunk_size, vocab_size);

      auto weight_chunk = weight.slice(0, start_idx, end_idx);  // [chunk, hidden]

      std::optional<Tensor> bias_chunk;
      if (bias.defined()) {
        bias_chunk = bias.slice(0, start_idx, end_idx);
      }

      auto logits_chunk = at::linear(input_flat, weight_chunk, bias_chunk);  // [N, chunk]

      if (label_smoothing > 0.0) {
        sum_logits = at::add(sum_logits, at::sum(logits_chunk, {-1}));
      }

      auto chunk_max = std::get<0>(logits_chunk.max(-1));
      auto new_max = at::maximum(running_max, chunk_max);

      auto exp_scale_old = at::exp(at::sub(running_max, new_max));
      auto shifted_logits = at::sub(logits_chunk, new_max.unsqueeze(-1));
      auto exp_chunk = at::sum(at::exp(shifted_logits), {-1});
      exp_sums = at::add(at::mul(exp_sums, exp_scale_old), exp_chunk);
      running_max = new_max;

      auto lower_bound = at::ge(target_flat, start_idx);
      auto upper_bound = at::lt(target_flat, end_idx);
      auto target_chunk_mask = at::logical_and(valid_mask, lower_bound);
      target_chunk_mask = at::logical_and(target_chunk_mask, upper_bound);

      auto indices = target_chunk_mask.nonzero().reshape({-1});
      if (indices.numel() > 0) {
        auto selected_targets = at::index_select(target_flat, 0, indices);
        auto local_targets = at::sub(selected_targets, start_idx);
        auto selected_logits = at::index_select(logits_chunk, 0, indices);
        auto gathered = at::gather(selected_logits, 1, local_targets.unsqueeze(1)).squeeze(1);
        target_logits.index_put_({indices}, gathered);
        auto ones_update = at::ones(indices.sizes(), long_options);
        target_found.index_put_({indices}, ones_update);
      }
    }

    auto target_found_mask = target_found.gt(0);
    auto coverage_mask = at::logical_or(target_found_mask, at::logical_not(valid_mask));
    TORCH_CHECK(coverage_mask.all().item<bool>(),
                "linear_cross_entropy: target index not found in vocabulary chunks");

    auto logsumexp = at::add(running_max, at::log(exp_sums));
    Tensor losses;
    if (label_smoothing > 0.0) {
      const double smoothing = label_smoothing;
      const double uniform = smoothing / static_cast<double>(vocab_size);
      auto main_term = at::mul(target_logits, 1.0 - smoothing);
      auto uniform_term = at::mul(sum_logits, uniform);
      losses = at::sub(logsumexp, main_term);
      losses = at::sub(losses, uniform_term);
    } else {
      losses = at::sub(logsumexp, target_logits);
    }

    auto invalid_mask = at::logical_not(valid_mask);
    losses.masked_fill_(invalid_mask, 0);

    if (reduction == Reduction::None) {
      return losses.reshape(target.sizes());
    }

    auto total_loss = losses.sum();
    if (reduction == Reduction::Sum) {
      return total_loss;
    }

    const int64_t valid_count = valid_mask.sum().item<int64_t>();
    if (valid_count == 0) {
      return total_loss;  // Match cross_entropy behaviour when all targets ignored
    }
    return at::div(total_loss, valid_count);

  } else if (selected_strategy == ChunkingStrategy::BATCH_CHUNKING) {
    // Batch chunking implementation - call dedicated function
    return batch_chunking_cpu(input, weight, target, bias_opt, reduction, ignore_index, label_smoothing);

  } else {  // ChunkingStrategy::NAIVE
    // Naive implementation for small models or when chunking not beneficial
    auto logits = at::linear(input, weight, bias);
    auto logits_flat = logits.reshape({-1, logits.size(-1)});
    return at::cross_entropy_loss(logits_flat, target_flat, std::nullopt, reduction, ignore_index, label_smoothing);
  }
}
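Aside (standalone illustration, not part of LinearCrossEntropy.cpp): the vocabulary-chunked path above relies on a streaming logsumexp recurrence. The Python sketch below shows why rescaling the running exponential sum by exp(old_max - new_max) makes running_max + log(exp_sums) equal to logsumexp over the full vocabulary without ever holding the [N, V] logits at once; the function name is illustrative.

import torch

def streaming_logsumexp(logits, chunk=4096):
    n = logits.size(0)
    running_max = torch.full((n,), float("-inf"))
    exp_sums = torch.zeros(n)
    for s in range(0, logits.size(1), chunk):
        block = logits[:, s:s + chunk]
        new_max = torch.maximum(running_max, block.max(dim=-1).values)
        # Rescale the old sum into the new reference frame, then add this chunk.
        exp_sums = exp_sums * torch.exp(running_max - new_max) + torch.exp(
            block - new_max.unsqueeze(-1)
        ).sum(dim=-1)
        running_max = new_max
    return running_max + exp_sums.log()

x = torch.randn(8, 10000)
assert torch.allclose(streaming_logsumexp(x), torch.logsumexp(x, dim=-1), atol=1e-5)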
// The backward implementation mirrors the forward chunking strategy so we never
// materialize a full [N, vocab] tensor while reconstructing gradients. The
// helpers below break the work into vocabulary-chunked and batch-chunked paths
// and rely exclusively on ATen operators so that the code remains device agnostic
// and benefits from existing BLAS/cuBLAS bindings.
namespace {

inline Tensor zeros_like_tensor(const Tensor& src) {
  return at::_ops::zeros_like::call(src, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
}

inline Tensor zeros_like_or_undef(const std::optional<Tensor>& opt) {
  if (opt.has_value()) {
    return zeros_like_tensor(opt.value());
  }
  return Tensor();
}

inline Tensor cast_grad_output(const Tensor& grad_output, const Tensor& input) {
  return grad_output.to(input.scalar_type());
}

inline Tensor mask_invalid_rows(const Tensor& tensor, const Tensor& valid_mask) {
  auto mask = valid_mask.to(tensor.scalar_type()).unsqueeze(1);
  return at::mul(tensor, mask);
}

inline void apply_target_updates(
    Tensor& grad_chunk,
    const Tensor& target_flat,
    const Tensor& rows,
    int64_t offset,
    double label_smoothing) {
  if (rows.numel() == 0) {
    return;
  }
  auto selected_targets = at::index_select(target_flat, 0, rows);
  auto local_targets = selected_targets.add(-offset).to(at::kLong);
  auto gather = grad_chunk.index({rows, local_targets}).add(-(1.0 - label_smoothing));
  grad_chunk.index_put_({rows, local_targets}, gather);
}

inline void scale_grad_chunk(
    Tensor& grad_chunk,
    const Tensor& grad_output_tensor,
    const Tensor& grad_output_flat,
    int64_t reduction,
    int64_t valid_count) {
  if (reduction == Reduction::None) {
    grad_chunk.mul_(grad_output_flat.unsqueeze(1));
    return;
  }
  if (reduction == Reduction::Sum) {
    grad_chunk.mul_(grad_output_tensor);
    return;
  }
  TORCH_CHECK(valid_count >= 0, "Valid element count must be non-negative");
  if (valid_count == 0) {
    grad_chunk.zero_();
    return;
  }
  auto scale = grad_output_tensor.div(static_cast<double>(valid_count));
  grad_chunk.mul_(scale);
}

// Computes gradients when the forward pass chose vocabulary chunking. We run a
// first pass to rebuild the per-sample logsumexp using the same streaming scheme
// as the forward kernel, then revisit each chunk to accumulate gradients for the
// input, weight and (optional) bias tensors.
inline std::tuple<Tensor, Tensor, std::optional<Tensor>> backward_vocabulary_chunking(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& target,
    const std::optional<Tensor>& bias_opt,
    const Tensor& grad_output,
    int64_t reduction,
    int64_t ignore_index,
    double label_smoothing,
    ChunkingStrategy resolved_strategy) {
  const auto input_flat = input.reshape({-1, input.size(-1)});
  const auto target_flat = target.reshape({-1});
  const auto dtype = input.scalar_type();
  const auto options = input.options();
  Tensor valid_mask = at::ne(target_flat, ignore_index);
  const int64_t valid_count = valid_mask.sum().item<int64_t>();

  if (reduction == Reduction::Mean && valid_count == 0) {
    Tensor grad_input = zeros_like_tensor(input);
    Tensor grad_weight = zeros_like_tensor(weight);
    Tensor grad_bias = zeros_like_or_undef(bias_opt);
    std::optional<Tensor> grad_bias_opt;
    if (grad_bias.defined()) {
      grad_bias_opt = std::move(grad_bias);
    }
    return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias_opt));
  }

  const int64_t vocab_size = weight.size(0);
  const int64_t chunk_size = 4096;
  const int64_t num_chunks = (vocab_size + chunk_size - 1) / chunk_size;

  Tensor running_max = at::full({input_flat.size(0)}, -std::numeric_limits<double>::infinity(), options).to(dtype);
  Tensor exp_sums = at::zeros({input_flat.size(0)}, options).to(dtype);

  for (int64_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
    const int64_t start_idx = chunk_idx * chunk_size;
    const int64_t end_idx = std::min(start_idx + chunk_size, vocab_size);
    auto weight_chunk = weight.slice(0, start_idx, end_idx);
    std::optional<Tensor> bias_chunk;
    if (bias_opt.has_value()) {
      bias_chunk = bias_opt->slice(0, start_idx, end_idx);
    }
    auto logits_chunk = at::linear(input_flat, weight_chunk, bias_chunk);
    auto chunk_max = std::get<0>(logits_chunk.max(-1));
    auto new_max = at::maximum(running_max, chunk_max);
    auto exp_scale_old = at::exp(running_max.sub(new_max));
    auto shifted_logits = logits_chunk.sub(new_max.unsqueeze(-1));
    auto exp_chunk = at::sum(at::exp(shifted_logits), {-1});
    exp_sums = at::add(at::mul(exp_sums, exp_scale_old), exp_chunk);
    running_max = new_max;
  }

  Tensor logsumexp = running_max.add(exp_sums.log());
  Tensor grad_input = zeros_like_tensor(input_flat);
  Tensor grad_weight = zeros_like_tensor(weight);
  Tensor grad_bias = zeros_like_or_undef(bias_opt);

  const double uniform_component = label_smoothing > 0.0 ? label_smoothing / static_cast<double>(vocab_size) : 0.0;
  Tensor grad_output_tensor = cast_grad_output(grad_output, input);
  Tensor grad_output_flat;
  if (reduction == Reduction::None) {
    grad_output_flat = grad_output_tensor.reshape(-1);
  }

  for (int64_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
    const int64_t start_idx = chunk_idx * chunk_size;
    const int64_t end_idx = std::min(start_idx + chunk_size, vocab_size);
    auto weight_chunk = weight.slice(0, start_idx, end_idx);
    std::optional<Tensor> bias_chunk;
    if (bias_opt.has_value()) {
      bias_chunk = bias_opt->slice(0, start_idx, end_idx);
    }
    auto logits_chunk = at::linear(input_flat, weight_chunk, bias_chunk);
    auto grad_chunk = at::exp(logits_chunk.sub(logsumexp.unsqueeze(-1)));
    if (label_smoothing > 0.0) {
      grad_chunk = grad_chunk.add(-uniform_component);
    }
    grad_chunk = mask_invalid_rows(grad_chunk, valid_mask);
    auto lower_bound = target_flat.ge(start_idx);
    auto upper_bound = target_flat.lt(end_idx);
    auto target_chunk_mask = at::logical_and(valid_mask, lower_bound);
    target_chunk_mask = at::logical_and(target_chunk_mask, upper_bound);
    auto rows = target_chunk_mask.nonzero().squeeze(-1);
    apply_target_updates(grad_chunk, target_flat, rows, start_idx, label_smoothing);
    scale_grad_chunk(grad_chunk, grad_output_tensor, grad_output_flat, reduction, valid_count);
    grad_chunk = mask_invalid_rows(grad_chunk, valid_mask);
    grad_input.add_(grad_chunk.matmul(weight_chunk));
    grad_weight.slice(0, start_idx, end_idx).add_(grad_chunk.transpose(0, 1).matmul(input_flat));
    if (grad_bias.defined()) {
      grad_bias.slice(0, start_idx, end_idx).add_(grad_chunk.sum(0));
    }
  }

  grad_input = grad_input.reshape_as(input);
  std::optional<Tensor> grad_bias_opt;
  if (grad_bias.defined()) {
    grad_bias_opt = std::move(grad_bias);
  }
  return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias_opt));
}

// Computes gradients when we chunked the batch dimension (or not at all). The
// loop keeps the working set bounded by `chunk_size` rows so that we never
// allocate a full [N, vocab] buffer even when the batch is very large.
inline std::tuple<Tensor, Tensor, std::optional<Tensor>> backward_batch_chunking(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& target,
    const std::optional<Tensor>& bias_opt,
    const Tensor& grad_output,
    int64_t reduction,
    int64_t ignore_index,
    double label_smoothing,
    int64_t chunk_size) {
  const auto input_flat = input.reshape({-1, input.size(-1)});
  const auto target_flat = target.reshape({-1});
  Tensor valid_mask = at::ne(target_flat, ignore_index);
  const int64_t valid_count = valid_mask.sum().item<int64_t>();

  if (reduction == Reduction::Mean && valid_count == 0) {
    Tensor grad_input = zeros_like_tensor(input);
    Tensor grad_weight = zeros_like_tensor(weight);
    Tensor grad_bias = zeros_like_or_undef(bias_opt);
    std::optional<Tensor> grad_bias_opt;
    if (grad_bias.defined()) {
      grad_bias_opt = std::move(grad_bias);
    }
    return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias_opt));
  }

  Tensor grad_input = zeros_like_tensor(input_flat);
  Tensor grad_weight = zeros_like_tensor(weight);
  Tensor grad_bias = zeros_like_or_undef(bias_opt);

  Tensor grad_output_tensor = cast_grad_output(grad_output, input);
  Tensor grad_output_flat;
  if (reduction == Reduction::None) {
    grad_output_flat = grad_output_tensor.reshape(-1);
  }

  const double uniform_component = label_smoothing > 0.0 ? label_smoothing / static_cast<double>(weight.size(0)) : 0.0;
  const int64_t total = input_flat.size(0);

  for (int64_t start_idx = 0; start_idx < total; start_idx += chunk_size) {
    const int64_t slice = std::min<int64_t>(chunk_size, total - start_idx);
    auto input_chunk = input_flat.narrow(0, start_idx, slice);
    auto target_chunk = target_flat.narrow(0, start_idx, slice);
    auto valid_mask_chunk = valid_mask.narrow(0, start_idx, slice);
    auto logits_chunk = at::linear(input_chunk, weight, bias_opt);
    auto logsumexp_chunk = at::_ops::logsumexp::call(logits_chunk, std::vector<int64_t>{1}, false);
    auto grad_chunk = at::exp(logits_chunk.sub(logsumexp_chunk.unsqueeze(-1)));
    if (label_smoothing > 0.0) {
      grad_chunk = grad_chunk.add(-uniform_component);
    }
    grad_chunk = mask_invalid_rows(grad_chunk, valid_mask_chunk);
    auto rows = valid_mask_chunk.nonzero().squeeze(-1);
    if (rows.numel() > 0) {
      auto targets_slice = at::index_select(target_chunk, 0, rows).to(at::kLong);
      auto gather = grad_chunk.index({rows, targets_slice}).add(-(1.0 - label_smoothing));
      grad_chunk.index_put_({rows, targets_slice}, gather);
    }
    Tensor grad_scale = grad_output_tensor;
    if (reduction == Reduction::None) {
      grad_chunk.mul_(grad_output_flat.narrow(0, start_idx, slice).unsqueeze(1));
    } else if (reduction == Reduction::Sum) {
      grad_chunk.mul_(grad_scale);
    } else {
      if (valid_count == 0) {
        continue;
      }
      grad_chunk.mul_(grad_scale.div(static_cast<double>(valid_count)));
    }
    grad_chunk = mask_invalid_rows(grad_chunk, valid_mask_chunk);
    grad_input.narrow(0, start_idx, slice).add_(grad_chunk.matmul(weight));
    grad_weight.add_(grad_chunk.transpose(0, 1).matmul(input_chunk));
    if (grad_bias.defined()) {
      grad_bias.add_(grad_chunk.sum(0));
    }
  }

  grad_input = grad_input.reshape_as(input);
  std::optional<Tensor> grad_bias_opt;
  if (grad_bias.defined()) {
    grad_bias_opt = std::move(grad_bias);
  }
  return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias_opt));
}

} // anonymous namespace

std::tuple<Tensor, Tensor, std::optional<Tensor>> linear_cross_entropy_backward_cpu(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& target,
    const std::optional<Tensor>& bias_opt,
    int64_t reduction,
    int64_t ignore_index,
    double label_smoothing,
    c10::string_view chunking_strategy) {
  TORCH_CHECK(input.dim() >= 2, "Expected input to have at least 2 dimensions, got ", input.dim());
  TORCH_CHECK(weight.dim() == 2, "Expected weight to be 2-dimensional, got ", weight.dim());
  TORCH_CHECK(input.size(-1) == weight.size(1),
              "Expected input.size(-1) to match weight.size(1), got ",
              input.size(-1), " and ", weight.size(1));
  TORCH_CHECK(target.device() == input.device(), "Target must be on the same device as input");

  const int64_t vocab_size = weight.size(0);
  const int64_t flattened_batch = input.numel() / input.size(-1);
  ChunkingStrategy resolved_strategy = select_chunking_strategy(vocab_size, flattened_batch, chunking_strategy);

  if (resolved_strategy == ChunkingStrategy::VOCAB_CHUNKING) {
    return backward_vocabulary_chunking(
        input,
        weight,
        target,
        bias_opt,
        grad_output,
        reduction,
        ignore_index,
        label_smoothing,
        resolved_strategy);
  }

  const int64_t default_chunk = resolved_strategy == ChunkingStrategy::BATCH_CHUNKING ? 1024 : flattened_batch;
  return backward_batch_chunking(
      input,
      weight,
      target,
      bias_opt,
      grad_output,
      reduction,
      ignore_index,
      label_smoothing,
      default_chunk);
}

} // namespace at::native
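Aside (standalone Python check, not part of the diff): the per-row gradient that both backward paths above reconstruct chunk by chunk is softmax(logits) minus the (optionally label-smoothed) target distribution, scaled by the reduction. A quick verification against autograd for mean reduction:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, v, eps = 5, 7, 0.1
logits = torch.randn(n, v, requires_grad=True)
target = torch.randint(0, v, (n,))

loss = F.cross_entropy(logits, target, label_smoothing=eps)
loss.backward()

# Smoothed target distribution: (1 - eps) * one_hot + eps / vocab_size.
one_hot = F.one_hot(target, v).to(logits.dtype)
target_dist = one_hot * (1.0 - eps) + eps / v
manual = (torch.softmax(logits, dim=-1).detach() - target_dist) / n
assert torch.allclose(logits.grad, manual, atol=1e-6)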
@@ -9488,6 +9488,14 @@
  dispatch:
    CompositeImplicitAutograd: cross_entropy_loss_symint

- func: linear_cross_entropy(Tensor input, Tensor weight, Tensor target, Tensor? bias=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0, str chunking_strategy="auto") -> Tensor
  dispatch:
    CPU: linear_cross_entropy_cpu

- func: linear_cross_entropy_backward(Tensor grad_output, Tensor input, Tensor weight, Tensor target, Tensor? bias=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0, str chunking_strategy="auto") -> (Tensor, Tensor, Tensor?)
  dispatch:
    CPU: linear_cross_entropy_backward_cpu

- func: triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient)
  structured: True
  dispatch:
@@ -314,6 +314,8 @@ aten::lift.out
aten::lift_fresh
aten::linalg_vector_norm
aten::linalg_vector_norm.out
aten::linear_cross_entropy
aten::linear_cross_entropy_backward
aten::log
aten::log.out
aten::log10
test/nn/test_linear_cross_entropy.py (new file, 256 lines)
@@ -0,0 +1,256 @@
# Owner(s): ["module: nn"]

import itertools
from typing import Optional

import torch
import torch.nn.functional as F
from torch.testing._internal.common_utils import run_tests, TestCase


def _reference_linear_cross_entropy(
    input: torch.Tensor,
    weight: torch.Tensor,
    target: torch.Tensor,
    bias: Optional[torch.Tensor],
    reduction: str,
    ignore_index: int,
    label_smoothing: float,
) -> torch.Tensor:
    logits = F.linear(input, weight, bias)
    logits_flat = logits.reshape(-1, logits.size(-1))
    target_flat = target.reshape(-1)
    loss = F.cross_entropy(
        logits_flat,
        target_flat,
        reduction=reduction,
        ignore_index=ignore_index,
        label_smoothing=label_smoothing,
    )
    if reduction == "none":
        loss = loss.reshape(target.shape)
    return loss


class TestLinearCrossEntropyCPU(TestCase):
    def _compare_with_reference(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        target: torch.Tensor,
        bias: Optional[torch.Tensor],
        *,
        reduction: str = "mean",
        ignore_index: int = -100,
        label_smoothing: float = 0.0,
        chunking_strategy: str = "auto",
    ) -> None:
        input_clone = input.clone(memory_format=torch.preserve_format).requires_grad_(
            input.requires_grad
        )
        weight_clone = weight.clone(memory_format=torch.preserve_format).requires_grad_(
            weight.requires_grad
        )
        bias_clone = None
        if bias is not None:
            bias_clone = bias.clone(memory_format=torch.preserve_format).requires_grad_(
                bias.requires_grad
            )

        fused = F.linear_cross_entropy(
            input,
            weight,
            target,
            bias,
            reduction=reduction,
            ignore_index=ignore_index,
            label_smoothing=label_smoothing,
            chunking_strategy=chunking_strategy,
        )
        ref = _reference_linear_cross_entropy(
            input_clone,
            weight_clone,
            target,
            bias_clone,
            reduction=reduction,
            ignore_index=ignore_index,
            label_smoothing=label_smoothing,
        )

        if fused.requires_grad:
            grad_args = [
                tensor for tensor in (input, weight, bias) if tensor is not None
            ]
            grad_args_ref = [
                tensor
                for tensor in (input_clone, weight_clone, bias_clone)
                if tensor is not None
            ]

            if reduction == "none":
                grad_output = torch.ones_like(fused)
                fused_grads = torch.autograd.grad(
                    fused,
                    grad_args,
                    grad_outputs=grad_output,
                    retain_graph=False,
                    allow_unused=True,
                )
                ref_grads = torch.autograd.grad(
                    ref,
                    grad_args_ref,
                    grad_outputs=grad_output,
                    retain_graph=False,
                    allow_unused=True,
                )
            else:
                fused_grads = torch.autograd.grad(
                    fused,
                    grad_args,
                    retain_graph=False,
                    allow_unused=True,
                )
                ref_grads = torch.autograd.grad(
                    ref,
                    grad_args_ref,
                    retain_graph=False,
                    allow_unused=True,
                )

            for grad_fused, grad_ref, tensor in zip(fused_grads, ref_grads, grad_args):
                if grad_fused is None or grad_ref is None:
                    self.assertTrue(grad_fused is None and grad_ref is None)
                else:
                    self.assertEqual(grad_fused, grad_ref)

        if reduction == "none":
            self.assertEqual(fused.shape, target.shape)
            self.assertEqual(ref.shape, target.shape)
        self.assertEqual(fused, ref)

    def test_forward_backward_matches_reference_auto(self) -> None:
        torch.manual_seed(0)
        input = torch.randn(2, 3, 32, requires_grad=True)
        weight = torch.randn(6000, 32, requires_grad=True)
        bias = torch.randn(6000, requires_grad=True)
        target = torch.randint(0, 6000, (2, 3))
        self._compare_with_reference(
            input, weight, target, bias, chunking_strategy="auto"
        )

    def test_vocab_chunking(self) -> None:
        torch.manual_seed(0)
        input = torch.randn(4, 16, requires_grad=True)
        weight = torch.randn(5000, 16, requires_grad=True)
        target = torch.randint(0, 5000, (4,))
        self._compare_with_reference(
            input, weight, target, None, chunking_strategy="vocab"
        )

    def test_batch_chunking(self) -> None:
        torch.manual_seed(0)
        input = torch.randn(1500, 8, requires_grad=True)
        weight = torch.randn(64, 8, requires_grad=True)
        target = torch.randint(0, 64, (1500,))
        self._compare_with_reference(
            input, weight, target, None, chunking_strategy="batch"
        )

    def test_non_contiguous_inputs(self) -> None:
        torch.manual_seed(0)
        base = torch.randn(2, 3, 5)
        input_nc = base.transpose(0, 1)
        weight = torch.randn(4, 5)
        target_base = torch.randint(0, 4, (2, 3), dtype=torch.long)
        target_nc = target_base.transpose(0, 1)
        self._compare_with_reference(
            input_nc,
            weight,
            target_nc,
            None,
            chunking_strategy="batch",
        )

    def test_auto_chunking_high_rank(self) -> None:
        torch.manual_seed(0)
        input = torch.randn(2, 3, 4, 5, 6)
        weight = torch.randn(8, 6)
        target = torch.randint(0, 8, (2, 3, 4, 5))
        self._compare_with_reference(
            input, weight, target, None, chunking_strategy="auto"
        )

    def test_all_targets_ignored(self) -> None:
        torch.manual_seed(0)
        input = torch.randn(512, 16)
        weight = torch.randn(128, 16)
        bias = torch.randn(128)
        target = torch.full((512,), -1, dtype=torch.long)

        for reduction in ("mean", "sum"):
            self._compare_with_reference(
                input,
                weight,
                target,
                bias,
                reduction=reduction,
                ignore_index=-1,
                chunking_strategy="batch",
            )

    def test_reduction_and_options(self) -> None:
        torch.manual_seed(0)
        input = torch.randn(3, 4, 8, requires_grad=True)
        weight = torch.randn(16, 8, requires_grad=True)
        bias = torch.randn(16, requires_grad=True)
        target = torch.randint(0, 16, (3, 4))

        for reduction, label_smoothing in itertools.product(
            ["none", "sum", "mean"], [0.0, 0.2]
        ):
            self._compare_with_reference(
                input.clone().requires_grad_(),
                weight.clone().requires_grad_(),
                target,
                bias.clone().requires_grad_(),
                reduction=reduction,
                label_smoothing=label_smoothing,
                ignore_index=-1,
            )

    def test_parameter_validation(self) -> None:
        x = torch.randn(2, 4)
        w = torch.randn(8, 4)
        t = torch.randint(0, 8, (2,))

        with self.assertRaisesRegex(ValueError, "reduction"):
            F.linear_cross_entropy(x, w, t, reduction="invalid")
        with self.assertRaisesRegex(ValueError, "label_smoothing"):
            F.linear_cross_entropy(x, w, t, label_smoothing=-0.1)
        with self.assertRaisesRegex(ValueError, "label_smoothing"):
            F.linear_cross_entropy(x, w, t, label_smoothing=1.1)
        with self.assertRaisesRegex(ValueError, "chunking_strategy"):
            F.linear_cross_entropy(x, w, t, chunking_strategy="other")

    def test_gradcheck(self) -> None:
        torch.manual_seed(0)
        input = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
        weight = torch.randn(20, 5, dtype=torch.double, requires_grad=True)
        bias = torch.randn(20, dtype=torch.double, requires_grad=True)
        target = torch.randint(0, 20, (3,), dtype=torch.long)

        def func(inp, wgt, b):
            return F.linear_cross_entropy(
                inp,
                wgt,
                target,
                b,
                reduction="mean",
                chunking_strategy="vocab",
            )

        self.assertTrue(torch.autograd.gradcheck(func, (input, weight, bias)))


if __name__ == "__main__":
    run_tests()
@@ -401,6 +401,8 @@ CROSS_REF_EXCLUDE_SET = {
    # CompositeAutogradImplicit
    # See https://github.com/pytorch/pytorch/issues/81669
    (None, None, "nn.functional.relu6"),
    (None, None, "nn.functional.linear_cross_entropy"),
    (None, None, "aten.linear_cross_entropy"),
    # This decomp runs before autograd.
    (None, None, "nn.functional.rrelu"),
    (None, None, "meshgrid"),
@@ -4682,6 +4682,7 @@ class TestFunctionalTracing(JitTestCase):
        "celu": CONTROL_FLOW,
        "cosine_embedding_loss": CONTROL_FLOW,
        "cross_entropy": CONTROL_FLOW,
        "linear_cross_entropy": CONTROL_FLOW,
        "ctc_loss": CONTROL_FLOW,
        "dropout": CONTROL_FLOW,
        "dropout1d": CONTROL_FLOW,
@@ -2027,6 +2027,12 @@
    + binary_cross_entropy_with_logits_target_backward(target_t, self_p, target_p, weight, pos_weight, at::Reduction::None),
    reduction)"

- name: linear_cross_entropy(Tensor input, Tensor weight, Tensor target, Tensor? bias=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0, str chunking_strategy="auto") -> Tensor
  input: std::get<0>(linear_cross_entropy_backward_symint(grad, input, weight, target, bias, reduction, ignore_index, label_smoothing, chunking_strategy))
  weight: std::get<1>(linear_cross_entropy_backward_symint(grad, input, weight, target, bias, reduction, ignore_index, label_smoothing, chunking_strategy))
  bias: std::get<2>(linear_cross_entropy_backward_symint(grad, input, weight, target, bias, reduction, ignore_index, label_smoothing, chunking_strategy)).value_or(at::Tensor())

- name: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
  indices: non_differentiable
  weight: embedding_backward_symint(grad, indices, weight.sym_size(0), padding_idx, scale_grad_by_freq, sparse)
@@ -39,6 +39,11 @@ from torch.utils._pytree import tree_map

DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]


_LINEAR_CROSS_ENTROPY_NAIVE: Optional[Callable[..., Tensor]] = getattr(
    F, "_linear_cross_entropy_naive", None
)

# None of these functions are publicly accessible; get at them
# from torch._decomps
__all__: list[str] = []
@@ -655,6 +660,166 @@ def binary_cross_entropy_backward(
    return result


@register_decomposition(aten.linear_cross_entropy)
@out_wrapper()
def linear_cross_entropy(
    input: Tensor,
    weight: Tensor,
    target: Tensor,
    bias: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    chunking_strategy: str = "auto",
) -> Tensor:
    logits = aten.linear.default(input, weight, bias)
    logits_flat = aten.reshape.default(logits, [-1, logits.size(-1)])
    target_flat = aten.reshape.default(target, [-1])
    if target.dtype.is_floating_point:
        if _LINEAR_CROSS_ENTROPY_NAIVE is None:
            raise RuntimeError(
                "linear_cross_entropy decomposition requires "
                "torch.nn.functional._linear_cross_entropy_naive"
            )

        reduction_str = (
            "mean"
            if reduction == Reduction.MEAN.value
            else "sum"
            if reduction == Reduction.SUM.value
            else "none"
        )

        return _LINEAR_CROSS_ENTROPY_NAIVE(  # type: ignore[misc]
            input,
            weight,
            target,
            bias,
            reduction_str,
            ignore_index,
            label_smoothing,
        )
    target_indices = aten.to.dtype(target_flat, torch.long)
    loss = aten.cross_entropy_loss.default(
        logits_flat,
        target_indices,
        None,
        reduction,
        ignore_index,
        label_smoothing,
    )
    if reduction == Reduction.NONE.value:
        loss = aten.reshape.default(loss, target.shape)
    return loss


@register_decomposition(aten.linear_cross_entropy_backward)
def linear_cross_entropy_backward(
    grad_output: Tensor,
    input: Tensor,
    weight: Tensor,
    target: Tensor,
    bias: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    chunking_strategy: str = "auto",
) -> tuple[Tensor, Tensor, Optional[Tensor]]:
    logits = aten.linear.default(input, weight, bias)
    logits_flat = aten.reshape.default(logits, [-1, logits.size(-1)])
    input_flat = aten.reshape.default(input, [-1, input.size(-1)])
    target_flat = aten.reshape.default(target, [-1])

    if target.dtype.is_floating_point:
        if _LINEAR_CROSS_ENTROPY_NAIVE is None:
            raise RuntimeError(
                "linear_cross_entropy decomposition requires "
                "torch.nn.functional._linear_cross_entropy_naive"
            )

        reduction_str = (
            "mean"
            if reduction == Reduction.MEAN.value
            else "sum"
            if reduction == Reduction.SUM.value
            else "none"
        )

        ref = _LINEAR_CROSS_ENTROPY_NAIVE(  # type: ignore[misc]
            input,
            weight,
            target,
            bias,
            reduction_str,
            ignore_index,
            label_smoothing,
        )
        grad_inputs = torch.autograd.grad(
            ref,
            (input, weight) if bias is None else (input, weight, bias),
            grad_output,
            retain_graph=True,
            allow_unused=True,
        )
        grad_input = (
            grad_inputs[0] if grad_inputs[0] is not None else torch.zeros_like(input)
        )
        grad_weight = (
            grad_inputs[1] if grad_inputs[1] is not None else torch.zeros_like(weight)
        )
        grad_bias = grad_inputs[2] if bias is not None else None
        return grad_input, grad_weight, grad_bias

    target_indices = aten.to.dtype(target_flat, torch.long)

    vocab_size = weight.size(0)
    prob = aten.softmax.default(logits_flat, -1)
    valid_mask = aten.ne(target_indices, ignore_index)
    safe_targets = aten.where(
        valid_mask, target_indices, aten.zeros_like(target_indices)
    )
    target_one_hot = aten.zeros_like(prob)
    target_one_hot = aten.scatter.value(
        target_one_hot,
        1,
        aten.reshape.default(safe_targets, [-1, 1]),
        1.0,
    )
    mask = aten.unsqueeze(valid_mask.to(prob.dtype), 1)
    target_one_hot = target_one_hot * mask

    if label_smoothing > 0.0:
        smoothing = label_smoothing
        uniform = smoothing / float(vocab_size) if vocab_size > 0 else 0.0
        target_dist = target_one_hot * (1.0 - smoothing)
        target_dist = target_dist + mask * uniform
    else:
        target_dist = target_one_hot

    grad_logits = (prob - target_dist) * mask
    grad_output_tensor = grad_output.to(grad_logits.dtype)

    if reduction == Reduction.NONE.value:
        grad_output_flat = aten.reshape.default(grad_output_tensor, [-1, 1])
        grad_logits = grad_logits * grad_output_flat
    elif reduction == Reduction.SUM.value:
        grad_logits = grad_logits * grad_output_tensor
    else:
        valid_count = aten.sum.default(mask)
        scale = grad_output_tensor / aten.clamp_min(valid_count, 1.0)
        grad_logits = grad_logits * scale

    grad_input_flat = aten.mm(grad_logits, weight)
    grad_input = aten.reshape.default(grad_input_flat, input.shape)
    grad_weight = aten.mm(aten.transpose(grad_logits, 0, 1), input_flat)
    if bias is not None:
        grad_bias_result: Optional[Tensor] = aten.sum.dim_IntList(grad_logits, [0])
    else:
        grad_bias_result = None

    return grad_input, grad_weight, grad_bias_result


@register_decomposition(aten.soft_margin_loss)
@out_wrapper()
@pw_cast_for_opmath
@@ -3503,6 +3503,156 @@ def cross_entropy(
    )


def _linear_cross_entropy_naive(
    input: Tensor,
    weight: Tensor,
    target: Tensor,
    bias: Optional[Tensor],
    reduction: str,
    ignore_index: int,
    label_smoothing: float,
) -> Tensor:
    logits = linear(input, weight, bias)
    logits_flat = logits.reshape(-1, logits.size(-1))
    target_flat = target.reshape(-1)
    loss = cross_entropy(
        logits_flat,
        target_flat,
        reduction=reduction,
        ignore_index=ignore_index,
        label_smoothing=label_smoothing,
    )
    if reduction == "none":
        loss = loss.reshape(target.shape)
    return loss


def linear_cross_entropy(
    input: Tensor,
    weight: Tensor,
    target: Tensor,
    bias: Optional[Tensor] = None,
    reduction: str = "mean",
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    chunking_strategy: str = "auto",
) -> Tensor:
    r"""Compute fused linear transformation and cross entropy loss on CPU.

    This is a convenience wrapper around :func:`linear` followed by
    :func:`cross_entropy`. When the inputs live on CPU it uses a fused ATen
    kernel that chunks the vocabulary or batch dimension to avoid materialising
    large logit tensors. For other devices it falls back to the unfused
    composition.
    """
    if has_torch_function_variadic(input, weight, target, bias):
        return handle_torch_function(
            linear_cross_entropy,
            (input, weight, target, bias),
            input,
            weight,
            target,
            bias=bias,
            reduction=reduction,
            ignore_index=ignore_index,
            label_smoothing=label_smoothing,
            chunking_strategy=chunking_strategy,
        )

    if not isinstance(reduction, str):
        if hasattr(reduction, "node"):
            from torch.fx.proxy import TraceError

            raise TraceError(
                "symbolically traced variables cannot be used as inputs to control flow"
            )
        raise ValueError(
            f"reduction must be one of ('mean', 'sum', 'none'), got '{reduction}'"
        )

    if reduction not in ("mean", "sum", "none"):
        raise ValueError(
            f"reduction must be one of ('mean', 'sum', 'none'), got '{reduction}'"
        )

    if not (0.0 <= label_smoothing <= 1.0):
        raise ValueError(
            f"label_smoothing must be between 0.0 and 1.0, got {label_smoothing}"
        )

    if chunking_strategy not in ("auto", "vocab", "batch", "none"):
        raise ValueError(
            "chunking_strategy must be one of ('auto', 'vocab', 'batch', 'none'), "
            f"got '{chunking_strategy}'"
        )

    if (
        input.device.type != "cpu"
        or weight.device != input.device
        or target.device != input.device
        or (bias is not None and bias.device != input.device)
    ):
        result = _linear_cross_entropy_naive(
            input,
            weight,
            target,
            bias,
            reduction,
            ignore_index,
            label_smoothing,
        )
    else:
        op = torch.ops.aten.linear_cross_entropy.default
        # Only exercise the fused path when the operator is actually built for
        # this runtime; otherwise fall back to the explicit composition so the
        # behaviour matches older binaries.
        if not op.has_kernel_for_dispatch_key("CPU"):
            result = _linear_cross_entropy_naive(
                input,
                weight,
                target,
                bias,
                reduction,
                ignore_index,
                label_smoothing,
            )
        else:
            reduction_enum = _Reduction.get_enum(reduction)
            needs_grad = torch.is_grad_enabled() and (
                input.requires_grad
                or weight.requires_grad
                or (bias is not None and bias.requires_grad)
            )
            # Some downstream builds may omit the generated autograd kernel; if
            # that happens we still provide gradients by delegating to the
            # unfused implementation.
            if needs_grad and not op.has_kernel_for_dispatch_key("AutogradCPU"):
                result = _linear_cross_entropy_naive(
                    input,
                    weight,
                    target,
                    bias,
                    reduction,
                    ignore_index,
                    label_smoothing,
                )
            else:
                result = torch.ops.aten.linear_cross_entropy(
                    input,
                    weight,
                    target,
                    bias,
                    reduction_enum,
                    ignore_index,
                    label_smoothing,
                    chunking_strategy,
                )

    if reduction == "none":
        return result.reshape(target.shape)
    return result


def binary_cross_entropy(
    input: Tensor,
    target: Tensor,
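Aside (illustrative usage, not part of the diff): calling the Python entry point defined above; this assumes a build that includes the new operator, and the shapes are arbitrary.

import torch
import torch.nn.functional as F

hidden = torch.randn(4, 128, 256, requires_grad=True)    # [batch, seq, hidden]
lm_head = torch.randn(32_000, 256, requires_grad=True)   # [vocab, hidden]
labels = torch.randint(0, 32_000, (4, 128))

loss = F.linear_cross_entropy(
    hidden, lm_head, labels,
    reduction="mean",
    ignore_index=-100,
    chunking_strategy="auto",   # vocab chunking is chosen here since vocab >> 4096
)
loss.backward()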
@@ -558,6 +558,9 @@ def get_testing_overrides() -> dict[Callable, Callable]:
        torch.ctc_loss: (
            lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
        ),
        torch.linear_cross_entropy: (
            lambda input, weight, target, bias=None, reduction=1, ignore_index=-100, label_smoothing=0.0, chunking_strategy="auto": -1  # noqa: B950
        ),
        torch.cummax: lambda input, dim, out=None: -1,
        torch.cummin: lambda input, dim, out=None: -1,
        torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
@@ -866,6 +869,9 @@ def get_testing_overrides() -> dict[Callable, Callable]:
        torch.nn.functional.cross_entropy: (
            lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", label_smoothing=0.0: -1  # noqa: B950
        ),
        torch.nn.functional.linear_cross_entropy: (
            lambda input, weight, target, bias=None, reduction="mean", ignore_index=-100, label_smoothing=0.0, chunking_strategy="auto": -1  # noqa: B950
        ),
        torch.nn.functional.ctc_loss: (
            lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
        ),
@@ -48,6 +48,7 @@ import torch._refs.nn.functional
import torch._refs.special
import torch._refs.linalg
import torch._prims as prims  # noqa: F401
import torch.nn.functional as F
from torch.utils import _pytree as pytree


@@ -6741,6 +6742,78 @@ def sample_inputs_cross_entropy(op_info, device, dtype, requires_grad, **kwargs)
        yield SampleInput(input, target, **kwargs)


def sample_inputs_linear_cross_entropy(op_info, device, dtype, requires_grad, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)

    # Force vocabulary chunking with a large vocab size.
    vocab_hidden = 16
    vocab_size = 4097
    input_vocab = make_input((2, vocab_hidden))
    weight_vocab = make_weight((vocab_size, vocab_hidden))
    target_vocab = make_target((2,), low=0, high=vocab_size)
    yield SampleInput(
        input_vocab,
        args=(weight_vocab, target_vocab),
        kwargs={"chunking_strategy": "vocab"},
    )

    # 3D input to trigger flattening logic (batch, sequence, hidden).
    seq_len = 5
    input_seq = make_input((3, seq_len, vocab_hidden))
    target_seq = make_target((3, seq_len), low=0, high=vocab_size)
    yield SampleInput(
        input_seq,
        args=(weight_vocab, target_seq),
        kwargs={"chunking_strategy": "vocab"},
    )

    # Force batch chunking with a large flattened batch.
    batch_hidden = 8
    batch_size = 1500
    input_batch = make_input((batch_size, batch_hidden))
    weight_batch = make_weight((64, batch_hidden))
    target_batch = make_target((batch_size,), low=0, high=64)
    yield SampleInput(
        input_batch,
        args=(weight_batch, target_batch),
        kwargs={"chunking_strategy": "batch"},
    )

    # 3D batch chunking (batch, seq, hidden) to exercise large flattened rows.
    seq_len_batch = 4
    batch_batch = 200
    input_batch_seq = make_input((batch_batch, seq_len_batch, batch_hidden))
    target_batch_seq = make_target((batch_batch, seq_len_batch), low=0, high=64)
    yield SampleInput(
        input_batch_seq,
        args=(weight_batch, target_batch_seq),
        kwargs={"chunking_strategy": "batch"},
    )


def reference_linear_cross_entropy(
    input,
    weight,
    target,
    bias=None,
    reduction="mean",
    ignore_index=-100,
    label_smoothing=0.0,
    chunking_strategy="auto",
):
    return F._linear_cross_entropy_naive(
        input,
        weight,
        target,
        bias,
        reduction,
        ignore_index,
        label_smoothing,
    )


def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs):
    low, high = op_info.domain

@@ -14778,6 +14851,55 @@ op_db: list[OpInfo] = [
        )
    ),
    OpInfo(
        "nn.functional.linear_cross_entropy",
        aten_name="linear_cross_entropy",
        dtypes=floating_types_and(torch.float16, torch.bfloat16),
        sample_inputs_func=sample_inputs_linear_cross_entropy,
        ref=reference_linear_cross_entropy,
        supports_out=False,
        supports_forward_ad=False,
        supports_fwgrad_bwgrad=False,
        decorators=(onlyCPU,),
        skips=(
            DecorateInfo(unittest.skip("linear_cross_entropy vmap support pending"),
                         "TestVmapOperatorsOpInfo",
                         "test_op_has_batch_rule"),
            DecorateInfo(unittest.skip("linear_cross_entropy vmap support pending"),
                         "TestVmapOperatorsOpInfoCPU",
                         "test_op_has_batch_rule"),
            DecorateInfo(unittest.skip("linear_cross_entropy vmap support pending"),
                         "TestVmapOperatorsOpInfo",
                         "test_vmap_exhaustive"),
            DecorateInfo(unittest.skip("linear_cross_entropy vmap support pending"),
                         "TestVmapOperatorsOpInfoCPU",
                         "test_vmap_exhaustive"),
            DecorateInfo(unittest.skip("linear_cross_entropy vmap support pending"),
                         "TestOperators",
                         "test_vjpvmap"),
            DecorateInfo(unittest.skip("linear_cross_entropy vmap support pending"),
                         "TestOperatorsCPU",
                         "test_vjpvmap"),
            DecorateInfo(unittest.skip("linear_cross_entropy forward-mode AD pending"),
                         "TestFwdGradients",
                         "test_fn_fwgrad_bwgrad"),
            DecorateInfo(unittest.skip("linear_cross_entropy forward-mode AD pending"),
                         "TestFwdGradientsCPU",
                         "test_fn_fwgrad_bwgrad"),
            DecorateInfo(unittest.skip("linear_cross_entropy JIT compliance pending"),
                         "TestJit",
                         "test_variant_consistency_jit"),
            DecorateInfo(unittest.skip("linear_cross_entropy JIT compliance pending"),
                         "TestJitCPU",
                         "test_variant_consistency_jit"),
            DecorateInfo(unittest.skip("linear_cross_entropy DTensor support pending"),
                         "TestDTensorOps",
                         "test_dtensor_op_db"),
            DecorateInfo(unittest.skip("linear_cross_entropy DTensor support pending"),
                         "TestDTensorOpsCPU",
                         "test_dtensor_op_db"),
        ),
    ),
    OpInfo('nn.functional.normalize',
           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
           sample_inputs_func=sample_inputs_normalize,