From 8a3cc863f854105b151ce3f7f03008869b56cc97 Mon Sep 17 00:00:00 2001
From: "Teng, Lu"
Date: Wed, 10 Dec 2025 20:31:33 +0800
Subject: [PATCH] Optimize as_strided_copy fast path to support offset

---
 torch_xla/csrc/aten_xla_type.cpp | 34 +++++++++++++++++++++++++++-----
 1 file changed, 29 insertions(+), 5 deletions(-)

diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 1c855ca8239..d7edf0ff1f4 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -938,9 +938,7 @@ static at::Tensor as_strided_eliminate_one_dim_fast_path(
   // we can shuffle element in `stride` and `size` and this can result in the
   // transpose of dimensions, but we don't consider this case here.
   auto tensor_dim = tensor.sizes();
-  if (storage_offset.has_value() && (*storage_offset != 0)) {
-    return at::Tensor();
-  }
+  bool has_offset = (storage_offset.has_value() && (*storage_offset != 0));
   if (tensor_dim.size() != stride.size() && tensor_dim.size() != stride.size() + 1) {
     return at::Tensor();
   }
@@ -971,6 +969,10 @@ static at::Tensor as_strided_eliminate_one_dim_fast_path(
   }
 
   if (tensor_dim.size() == stride.size() + 1) {
+    // Don't support offset in this case.
+    if (has_offset) {
+      return at::Tensor();
+    }
     for (long i = 0, j = 0; i < size.size(); i++, j++) {
       if (i == skip_dim) {
         j++;
@@ -1013,9 +1015,31 @@ static at::Tensor as_strided_eliminate_one_dim_fast_path(
     // stride.
     K = 1;
   }
+
+  // Calculate start index in reduce_size_location dimension from storage_offset
+  int64_t start = 0;
+  if (has_offset) {
+    int64_t base_storage_offset = tensor.storage_offset();
+    int64_t relative_storage_offset = (*storage_offset) - base_storage_offset;
+    if (stride[reduce_size_location] > 0) {
+      start = relative_storage_offset / stride[reduce_size_location];
+    } else {
+      // Negative stride not supported
+      return at::Tensor();
+    }
+    // Check if start is out of range
+    if (start < 0 || start >= tensor_dim[reduce_size_location]) {
+      return at::Tensor();
+    }
+  }
+  int64_t end = start + size[reduce_size_location] * K;
+  // Check if end is out of range
+  if (end > tensor_dim[reduce_size_location]) {
+    return at::Tensor();
+  }
   XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor, bridge::GetXlaTensor(tensor));
-  return bridge::AtenFromXlaTensor(tensor_methods::slice(
-      xla_tensor, reduce_size_location, 0, size[reduce_size_location] * K, K));
+  return bridge::AtenFromXlaTensor(
+      tensor_methods::slice(xla_tensor, reduce_size_location, start, end, K));
 }
 
 at::Tensor XLANativeFunctions::as_strided_copy(
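For context, the hunk at -1013 translates a non-zero `storage_offset` into the `start`/`end` arguments of `tensor_methods::slice` along the single reduced dimension. The sketch below is a minimal standalone restatement of that arithmetic, not code from the patch; `MapOffsetToSlice` and `SliceRange` are hypothetical names, and it assumes the surrounding fast path has already identified `reduce_size_location` and the slice step `K`.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical standalone restatement of the offset -> slice mapping added in
// the hunk at -1013; not part of the patch itself. Returns std::nullopt where
// the fast path bails out by returning an empty at::Tensor().
struct SliceRange {
  int64_t start;  // first index kept in the reduced dimension
  int64_t end;    // slice upper bound passed to tensor_methods::slice
};

std::optional<SliceRange> MapOffsetToSlice(
    int64_t base_storage_offset,  // tensor.storage_offset()
    int64_t requested_offset,     // *storage_offset from as_strided
    int64_t dim_stride,           // stride[reduce_size_location]
    int64_t dim_size,             // tensor_dim[reduce_size_location]
    int64_t out_dim_size,         // size[reduce_size_location]
    int64_t K) {                  // step derived earlier in the fast path
  // The requested offset is absolute; make it relative to the base tensor.
  int64_t relative = requested_offset - base_storage_offset;
  // Non-positive strides are rejected, mirroring the patch.
  if (dim_stride <= 0) return std::nullopt;
  // One step along the reduced dimension covers dim_stride elements of
  // storage, so the relative offset maps to a start index in that dimension.
  int64_t start = relative / dim_stride;
  // start must land inside the source dimension ...
  if (start < 0 || start >= dim_size) return std::nullopt;
  // ... and the strided window must not run past its end.
  int64_t end = start + out_dim_size * K;
  if (end > dim_size) return std::nullopt;
  return SliceRange{start, end};
}
```

Worked example with hypothetical numbers: for a contiguous 8x4 tensor (strides {4, 1}, storage offset 0), `as_strided(size={3, 4}, stride={4, 1}, storage_offset=8)` reduces dim 0 with K = 1, so start = 8 / 4 = 2 and end = 2 + 3 * 1 = 5; the fast path then emits a slice of rows [2, 5), matching the eager result.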