
Commit 6f2afea

Migrate extension-cpp to stable API/ABI
1 parent: 0ec4969

7 files changed: 198 additions, 109 deletions

.github/scripts/setup-env.sh

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ pip install --progress-bar=off -r requirements.txt
 echo '::endgroup::'
 
 echo '::group::Install extension-cpp'
-python setup.py develop
+pip install -e . --no-build-isolation
 echo '::endgroup::'
 
 echo '::group::Collect environment information'

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ jobs:
           - python-version: 3.13
             runner: linux.g5.4xlarge.nvidia.gpu
             gpu-arch-type: cuda
-            gpu-arch-version: "12.4"
+            gpu-arch-version: "12.9"
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     permissions:

README.md

Lines changed: 10 additions & 5 deletions
@@ -1,11 +1,16 @@
-# C++/CUDA Extensions in PyTorch
+# C++/CUDA Extensions in PyTorch with LibTorch Stable ABI
+
+An example of writing a C++/CUDA extension for PyTorch using the [LibTorch Stable ABI](https://pytorch.org/docs/main/notes/libtorch_stable_abi.html).
+See [here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial.
 
-An example of writing a C++/CUDA extension for PyTorch. See
-[here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial.
 This repo demonstrates how to write an example `extension_cpp.ops.mymuladd`
-custom op that has both custom CPU and CUDA kernels.
+custom op that has both custom CPU and CUDA kernels, it leverages the LibTorch
+Stable ABI to ensure that the extension built can be run with any version of
+PyTorch >= 2.10.0.
+
+The examples in this repo work with PyTorch 2.10+. For an example of how to use
+the non-stable subset of LibTorch, see [this previous commit](https://github.com/pytorch/extension-cpp/tree/0ec4969c7bc8e15a8456e5eb9d9ca0a7ec15bc95).
 
-The examples in this repo work with PyTorch 2.4+.
 
 To build:
 ```

extension_cpp/csrc/cuda/muladd.cu

Lines changed: 102 additions & 51 deletions
@@ -1,10 +1,18 @@
-#include <ATen/Operators.h>
-#include <torch/all.h>
-#include <torch/library.h>
+// LibTorch Stable ABI version of CUDA custom operators
+// This file uses the stable API for cross-version compatibility.
+// See: https://pytorch.org/docs/main/notes/libtorch_stable_abi.html
+
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
+#include <torch/csrc/stable/accelerator.h>
+#include <torch/headeronly/core/ScalarType.h>
+#include <torch/headeronly/macros/Macros.h>
+
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
 
 #include <cuda.h>
 #include <cuda_runtime.h>
-#include <ATen/cuda/CUDAContext.h>
 
 namespace extension_cpp {
 
@@ -13,21 +21,35 @@ __global__ void muladd_kernel(int numel, const float* a, const float* b, float c
   if (idx < numel) result[idx] = a[idx] * b[idx] + c;
 }
 
-at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymuladd_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    double c) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
 
   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // For now, we rely on the raw shim API to get the current CUDA stream.
+  // This will be improved in a future release.
+  // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
+  // check the error code and throw an appropriate runtime_error otherwise.
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
   muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
   return result;
 }
@@ -37,20 +59,34 @@ __global__ void mul_kernel(int numel, const float* a, const float* b, float* res
   if (idx < numel) result[idx] = a[idx] * b[idx];
 }
 
-at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymul_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
+
   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // For now, we rely on the raw shim API to get the current CUDA stream.
+  // This will be improved in a future release.
+  // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
+  // check the error code and throw an appropriate runtime_error otherwise.
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
   mul_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
   return result;
 }
@@ -60,32 +96,47 @@ __global__ void add_kernel(int numel, const float* a, const float* b, float* res
   if (idx < numel) result[idx] = a[idx] + b[idx];
 }
 
-void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(b.sizes() == out.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_CHECK(out.dtype() == at::kFloat);
-  TORCH_CHECK(out.is_contiguous());
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = out.data_ptr<float>();
+// An example of an operator that mutates one of its inputs.
+void myadd_out_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    torch::stable::Tensor& out) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.is_contiguous());
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = out.mutable_data_ptr<float>();
+
   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // For now, we rely on the raw shim API to get the current CUDA stream.
+  // This will be improved in a future release.
+  // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
+  // check the error code and throw an appropriate runtime_error otherwise.
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
   add_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
 }
 
-
 // Registers CUDA implementations for mymuladd, mymul, myadd_out
-TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
-  m.impl("mymuladd", &mymuladd_cuda);
-  m.impl("mymul", &mymul_cuda);
-  m.impl("myadd_out", &myadd_out_cuda);
+STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
+  m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
+  m.impl("mymul", TORCH_BOX(&mymul_cuda));
+  m.impl("myadd_out", TORCH_BOX(&myadd_out_cuda));
 }
 
 }
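
For readers following the migration pattern, below is a minimal sketch of how one more CUDA operator could be written against the same stable-ABI surface used in this file. It is hypothetical and not part of this commit: the `myscale` op, its `scale_kernel`, its schema string, and the extra registration block are illustrative assumptions; every API call it uses already appears in the diff above.

```cuda
// Hypothetical sketch (not part of this commit): one more CUDA op written
// against the same stable-ABI pattern as mymuladd_cuda above.
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#include <cuda_runtime.h>

namespace extension_cpp {

// Simple elementwise kernel: result[i] = a[i] * s.
__global__ void scale_kernel(int numel, const float* a, float s, float* result) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) result[idx] = a[idx] * s;
}

// Multiplies a float CUDA tensor by a scalar and returns a new tensor.
torch::stable::Tensor myscale_cuda(const torch::stable::Tensor& a, double s) {
  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);

  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
  torch::stable::Tensor result = torch::stable::empty_like(a_contig);

  const float* a_ptr = a_contig.const_data_ptr<float>();
  float* result_ptr = result.mutable_data_ptr<float>();
  int numel = a_contig.numel();

  // Same raw shim call as in the ops above: fetch the current CUDA stream and
  // wrap the call in TORCH_ERROR_CODE_CHECK so failures surface as errors.
  void* stream_ptr = nullptr;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);

  scale_kernel<<<(numel + 255) / 256, 256, 0, stream>>>(
      numel, a_ptr, static_cast<float>(s), result_ptr);
  return result;
}

// The schema ("myscale(Tensor a, float s) -> Tensor") would be declared in the
// STABLE_TORCH_LIBRARY(extension_cpp, m) block in muladd.cpp; the CUDA kernel
// is then registered with a boxed wrapper, exactly as for the other ops.
STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
  m.impl("myscale", TORCH_BOX(&myscale_cuda));
}

}  // namespace extension_cpp
```

As in the commit's own kernels, the typed function is passed through TORCH_BOX at registration time rather than registered directly, which is what the stable registration macros expect.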

extension_cpp/csrc/muladd.cpp

Lines changed: 71 additions & 49 deletions
@@ -1,14 +1,19 @@
+// LibTorch Stable ABI version of custom operators
+// This file uses the stable API for cross-version compatibility.
+// See: https://pytorch.org/docs/main/notes/libtorch_stable_abi.html
+
 #include <Python.h>
-#include <ATen/Operators.h>
-#include <torch/all.h>
-#include <torch/library.h>
 
-#include <vector>
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
+#include <torch/headeronly/core/ScalarType.h>
+#include <torch/headeronly/macros/Macros.h>
 
 extern "C" {
   /* Creates a dummy empty _C module that can be imported from Python.
     The import from Python will load the .so consisting of this file
-    in this extension, so that the TORCH_LIBRARY static initializers
+    in this extension, so that the STABLE_TORCH_LIBRARY static initializers
    below are run. */
   PyObject* PyInit__C(void)
   {
@@ -26,75 +31,92 @@ extern "C" {
 
 namespace extension_cpp {
 
-at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymuladd_cpu(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    double c) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
+
   for (int64_t i = 0; i < result.numel(); i++) {
     result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
   }
   return result;
 }
 
-at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymul_cpu(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
+
   for (int64_t i = 0; i < result.numel(); i++) {
     result_ptr[i] = a_ptr[i] * b_ptr[i];
   }
   return result;
 }
 
 // An example of an operator that mutates one of its inputs.
-void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(b.sizes() == out.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_CHECK(out.dtype() == at::kFloat);
-  TORCH_CHECK(out.is_contiguous());
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = out.data_ptr<float>();
+void myadd_out_cpu(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    torch::stable::Tensor& out) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.is_contiguous());
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CPU);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = out.mutable_data_ptr<float>();
+
   for (int64_t i = 0; i < out.numel(); i++) {
     result_ptr[i] = a_ptr[i] + b_ptr[i];
   }
 }
 
 // Defines the operators
-TORCH_LIBRARY(extension_cpp, m) {
+STABLE_TORCH_LIBRARY(extension_cpp, m) {
   m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
   m.def("mymul(Tensor a, Tensor b) -> Tensor");
   m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
 }
 
 // Registers CPU implementations for mymuladd, mymul, myadd_out
-TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-  m.impl("mymuladd", &mymuladd_cpu);
-  m.impl("mymul", &mymul_cpu);
-  m.impl("myadd_out", &myadd_out_cpu);
+STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+  m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
+  m.impl("mymul", TORCH_BOX(&mymul_cpu));
+  m.impl("myadd_out", TORCH_BOX(&myadd_out_cpu));
 }
 
 }
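
For completeness, a matching CPU sketch of the same hypothetical op, again not part of this commit and shown only to illustrate the pattern: `myscale_cpu`, its schema, and the extra registration block are assumptions; the body mirrors `mymul_cpu` above and uses only the stable APIs that appear in this diff.

```cpp
// Hypothetical sketch (not part of this commit): the CPU counterpart of a
// "myscale(Tensor a, float s) -> Tensor" op, using only the stable APIs above.
#include <cstdint>

#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>

namespace extension_cpp {

// Multiplies a float CPU tensor by a scalar and returns a new tensor.
torch::stable::Tensor myscale_cpu(const torch::stable::Tensor& a, double s) {
  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);

  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
  torch::stable::Tensor result = torch::stable::empty_like(a_contig);

  const float* a_ptr = a_contig.const_data_ptr<float>();
  float* result_ptr = result.mutable_data_ptr<float>();
  for (int64_t i = 0; i < result.numel(); i++) {
    result_ptr[i] = a_ptr[i] * static_cast<float>(s);
  }
  return result;
}

// The schema line would sit next to the existing m.def(...) calls in the
// STABLE_TORCH_LIBRARY(extension_cpp, m) block above; the CPU implementation
// is registered with a boxed wrapper, as for mymuladd/mymul/myadd_out.
STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
  m.impl("myscale", TORCH_BOX(&myscale_cpu));
}

}  // namespace extension_cpp
```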

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [build-system]
 requires = [
     "setuptools",
-    "torch",
+    "torch>=2.10.0",
 ]
 build-backend = "setuptools.build_meta"
