-#include <ATen/Operators.h>
-#include <torch/all.h>
-#include <torch/library.h>
+// LibTorch Stable ABI version of CUDA custom operators
+// This file uses the stable API for cross-version compatibility.
+// See: https://pytorch.org/docs/main/notes/libtorch_stable_abi.html
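+//
+// Note: this file only registers CUDA implementations; the operator schemas
+// are assumed to be defined elsewhere in the extension, roughly like:
+//
+//   STABLE_TORCH_LIBRARY(extension_cpp, m) {
+//     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
+//     m.def("mymul(Tensor a, Tensor b) -> Tensor");
+//     m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
+//   }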
+
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
+#include <torch/csrc/stable/accelerator.h>
+#include <torch/headeronly/core/ScalarType.h>
+#include <torch/headeronly/macros/Macros.h>
+
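+// The AOTI C shim provides aoti_torch_get_current_cuda_stream(), which is
+// used below to obtain the current CUDA stream for kernel launches.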
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
 
 #include <cuda.h>
 #include <cuda_runtime.h>
-#include <ATen/cuda/CUDAContext.h>
 
 namespace extension_cpp {
 
@@ -13,21 +21,35 @@ __global__ void muladd_kernel(int numel, const float* a, const float* b, float c
   if (idx < numel) result[idx] = a[idx] * b[idx] + c;
 }
 
-at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymuladd_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    double c) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
 
   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // For now, we rely on the raw shim API to get the current CUDA stream.
+  // This will be improved in a future release.
+  // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
+  // check the error code and throw an appropriate runtime_error otherwise.
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
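+  // Launch one thread per element: 256-thread blocks, with the grid size
+  // rounded up so that every element is covered.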
   muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
   return result;
 }
@@ -37,20 +59,34 @@ __global__ void mul_kernel(int numel, const float* a, const float* b, float* res
   if (idx < numel) result[idx] = a[idx] * b[idx];
 }
 
-at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymul_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
+
   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // For now, we rely on the raw shim API to get the current CUDA stream.
+  // This will be improved in a future release.
+  // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
+  // check the error code and throw an appropriate runtime_error otherwise.
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
   mul_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
   return result;
 }
@@ -60,32 +96,47 @@ __global__ void add_kernel(int numel, const float* a, const float* b, float* res
   if (idx < numel) result[idx] = a[idx] + b[idx];
 }
 
-void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(b.sizes() == out.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_CHECK(out.dtype() == at::kFloat);
-  TORCH_CHECK(out.is_contiguous());
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = out.data_ptr<float>();
+// An example of an operator that mutates one of its inputs.
+void myadd_out_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    torch::stable::Tensor& out) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.is_contiguous());
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = out.mutable_data_ptr<float>();
+
   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // For now, we rely on the raw shim API to get the current CUDA stream.
+  // This will be improved in a future release.
+  // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
+  // check the error code and throw an appropriate runtime_error otherwise.
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
   add_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
 }
 
-
 // Registers CUDA implementations for mymuladd, mymul, myadd_out
-TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
-  m.impl("mymuladd", &mymuladd_cuda);
-  m.impl("mymul", &mymul_cuda);
-  m.impl("myadd_out", &myadd_out_cuda);
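+// TORCH_BOX wraps each typed kernel in the boxed calling convention that the
+// stable registration API expects.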
+STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
+  m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
+  m.impl("mymul", TORCH_BOX(&mymul_cuda));
+  m.impl("myadd_out", TORCH_BOX(&myadd_out_cuda));
 }
 
 }