@Zantares

This PR optimizes the as_strided_copy fast path to support a nonzero storage offset.

A case that previously missed the fast path (it occurs in the default path of the torch LSTM):

import torch
import torch_xla


def torch_xla_chunk_example(input_tensor):
    chunks = torch.unsafe_chunk(input_tensor, 4, dim=1)
    result0 = chunks[0]
    result1 = chunks[1]
    return torch.sigmoid(result0), torch.tanh(result1)


def create_example():
    device = torch_xla.device()
    
    torch.manual_seed(12)
    input_tensor = torch.randn(10, 20, 30, dtype=torch.float32)
    
    result0, result1 = torch_xla_chunk_example(input_tensor.to(device)) 
    torch_xla.sync()
    print(f"Chunk 0:\n{result0}\n")
    print(f"Chunk 1:\n{result1}\n")
    
    return result0


if __name__ == "__main__":
    create_example()
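At the view level, each chunk produced by unsafe_chunk shares the input's strides and differs only in its storage offset; that offset is what decides whether the fast path applies. A minimal pure-Python sketch of those view parameters (the helper name is hypothetical; shapes chosen to match the example above):

```python
# Shapes from the example above: a contiguous (10, 20, 30) float tensor,
# chunked into four pieces along dim 1.
shape = (10, 20, 30)
strides = (20 * 30, 30, 1)  # contiguous strides, in elements

def chunk_view_params(chunk_idx, chunks=4, dim=1):
    """View parameters (sizes, strides, storage_offset) for one chunk."""
    size = list(shape)
    size[dim] = shape[dim] // chunks
    # Each chunk starts where the previous one ends along `dim`.
    offset = chunk_idx * size[dim] * strides[dim]
    return tuple(size), strides, offset

# Chunk 0 has storage offset 0 (already on the fast path: a plain slice);
# chunk 1 has storage offset 150, which previously fell back to "take".
print(chunk_view_params(0))  # ((10, 5, 30), (600, 30, 1), 0)
print(chunk_view_params(1))  # ((10, 5, 30), (600, 30, 1), 150)
```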

HLO before the optimization:

ENTRY %IrToHlo.19 (p0.1: f32[10,20,30], p1.4: s64[10,5,30]) -> (f32[10,5,30], f32[10,5,30]) {                                                                                   
  ...
  %slice.2 = f32[10,5,30]{2,1,0} slice(f32[10,20,30]{2,1,0} %p0.1), slice={[0:10], [0:5], [0:30]}, metadata={op_type="xla__select" op_name="xla__select"}
  ...
  %constant.10 = s64[] constant(0), metadata={op_type="aten__take" op_name="aten__take"}
  %broadcast.11 = s64[1500]{0} broadcast(s64[] %constant.10), dimensions={}, metadata={op_type="aten__take"}
  %compare.12 = pred[1500]{0} compare(s64[1500]{0} %reshape.6, s64[1500]{0} %broadcast.11), direction=GE, metadata={op_type="aten__take" op_name="aten__take"}
  %constant.7 = s64[] constant(6000), metadata={op_type="aten__take" op_name="aten__take"}
  %broadcast.8 = s64[1500]{0} broadcast(s64[] %constant.7), dimensions={}, metadata={op_type="aten__take"}
  %add.9 = s64[1500]{0} add(s64[1500]{0} %reshape.6, s64[1500]{0} %broadcast.8), metadata={op_type="aten__take"}
  %select.13 = s64[1500]{0} select(pred[1500]{0} %compare.12, s64[1500]{0} %reshape.6, s64[1500]{0} %add.9), metadata={op_type="aten__take" op_name="aten__take"}
  %convert.14 = u32[1500]{0} convert(s64[1500]{0} %select.13), metadata={op_type="aten__take" op_name="aten__take"}
  %gather.15 = f32[1500]{0} gather(f32[6000]{0} %reshape.5, u32[1500]{0} %convert.14), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_type="aten__take" op_name="aten__take"}
  ...
}
  • result0 (no storage offset) is lowered to a single slice
  • result1 (nonzero storage offset) is lowered to take, which expands into many ops (broadcast, compare, add, select, convert, gather)

The two operations are very similar except for the offset handling. In fact, the second result can also be produced by a plain slice: the storage offset just needs to be folded into the start parameter of the slice. This PR adds that optimization to as_strided_copy for the case where the offset falls within the same dimension being sliced.
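To illustrate why the two lowerings agree, here is a pure-Python sketch (shapes match the example above; the index arithmetic mirrors the gather in the HLO dump, not torch_xla's actual code):

```python
# Stand-in for the flattened f32[6000] buffer of the (10, 20, 30) input.
flat = list(range(6000))

# "take" path: gather 1500 elements at strided indices shifted by the
# storage offset 150 (= chunk index 1 * 5 * 30).
offset = 1 * 5 * 30
gathered = [flat[offset + i * 600 + j * 30 + k]
            for i in range(10) for j in range(5) for k in range(30)]

# Optimized path: fold the offset into the slice start along dim 1
# (offset // stride_of_dim1 = 150 // 30 = 5), i.e. x[:, 5:10, :].
sliced = [flat[i * 600 + j * 30 + k]
          for i in range(10) for j in range(5, 10) for k in range(30)]

assert gathered == sliced  # the two lowerings read the same elements
```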

The optimized HLO:

ENTRY %IrToHlo.7 (p0.1: f32[10,20,30]) -> (f32[10,5,30], f32[10,5,30]) {
  ...
  %slice.2 = f32[10,5,30]{2,1,0} slice(f32[10,20,30]{2,1,0} %p0.1), slice={[0:10], [0:5], [0:30]}, metadata={op_type="xla__select" op_name="xla__select"}
  ...
  %slice.4 = f32[10,5,30]{2,1,0} slice(f32[10,20,30]{2,1,0} %p0.1), slice={[0:10], [5:10], [0:30]}, metadata={op_type="xla__select" op_name="xla__select"} 
  ...
}
