From 974f6632268b4dacce81c2b5bcbde0de9cd165b5 Mon Sep 17 00:00:00 2001
From: Aravind-11
Date: Sat, 8 Nov 2025 19:37:25 -0700
Subject: [PATCH 1/7] Fixes #854 - linspace now correctly handles int64 dtype

---
 .../function_libs/torch_lib/ops/core.py | 31 ++++++++++---------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 2cbecdcfc2..d3254a4276 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -4984,30 +4984,33 @@ def aten_linspace(
     pin_memory: bool = False,
 ) -> TensorType:
     """linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
-
     if dtype == -1 or dtype is None:
         dtype = FLOAT.dtype
 
-    # Reference: https://github.com/pytorch/pytorch/blob/b35ca2cb941b5ba90858322810ca85c31e4541fd/torch/_refs/__init__.py#L4896
     if steps == 0:
         return aten_full(op.Constant(value_ints=[0]), 0.0, dtype=dtype)
     if steps == 1:
         return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)
 
-    rg = aten_arange_start(0, steps, dtype=dtype)
-    start = op.Cast(start, to=dtype)
-    end = op.Cast(end, to=dtype)
-    steps_float = op.Cast(steps, to=dtype)
-    one = op.Cast(1.0, to=dtype)
-    two = op.Cast(2.0, to=dtype)
-    steps_minus_1 = op.Cast(steps - 1, to=dtype)
-    step = op.Div(op.Sub(end, start), steps_minus_1)
-    return op.Where(
-        rg < op.Div(steps_float, two),
-        start + step * rg,
-        end - step * (steps_float - one - rg),
+    compute_dtype = FLOAT.dtype
+
+    rg = aten_arange_start(0, steps, dtype=compute_dtype)
+    start_f = op.Cast(start, to=compute_dtype)
+    end_f = op.Cast(end, to=compute_dtype)
+    steps_f = op.Cast(steps, to=compute_dtype)
+    one = op.Cast(1.0, to=compute_dtype)
+    two = op.Cast(2.0, to=compute_dtype)
+    steps_minus_1 = op.Sub(steps_f, one)
+    step = op.Div(op.Sub(end_f, start_f), steps_minus_1)
+
+    lin_vals = op.Where(
+        rg < op.Div(steps_f, two),
+        op.Add(start_f, op.Mul(step, rg)),
+        op.Sub(end_f, op.Mul(step, op.Sub(op.Sub(steps_f, one), rg))),
     )
 
+    return op.Cast(lin_vals, to=dtype)
+
 
 @torch_op("aten::log", trace_only=True)
 def aten_log(self: TFloat) -> TFloat:
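What this first patch emits is the two-sided linspace formula from PyTorch's _refs implementation (cited in the code it replaces), computed in float32 and only cast to the requested dtype at the very end. As a rough illustration only (NumPy stands in for the ONNX Arange/Where/Cast ops; linspace_two_sided is a throwaway name, not part of this change), the int64 case from issue #854 behaves like:

    import numpy as np

    def linspace_two_sided(start, end, steps, out_dtype=np.int64, compute_dtype=np.float32):
        rg = np.arange(steps, dtype=compute_dtype)                # index vector, like the Arange above
        step = (compute_dtype(end) - compute_dtype(start)) / compute_dtype(steps - 1)
        forward = compute_dtype(start) + step * rg                # first half: walk up from start
        backward = compute_dtype(end) - step * (steps - 1 - rg)   # second half: walk down from end
        vals = np.where(rg < steps / 2, forward, backward)        # the Where over the two halves
        return vals.astype(out_dtype)                             # final Cast; truncates for int dtypes

    print(linspace_two_sided(0, 10, 5))  # [ 0  2  5  7 10]
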
From 4112ce92df75683e95d82f1b6cda328040c2cc7f Mon Sep 17 00:00:00 2001
From: Aravind-11
Date: Sun, 16 Nov 2025 18:05:07 -0700
Subject: [PATCH 2/7] unskip test

---
 tests/function_libs/torch_lib/ops_test_data.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index b60fd8cf31..5bdd371b10 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -767,14 +767,6 @@ def _where_input_wrangler(
         core_ops.aten_linspace,
         tolerance={torch.float16: (2e-2, 2e-3)},
     )
-    .xfail(
-        dtypes=(torch.int64, torch.int32),
-        reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
-    )
-    .skip(
-        matcher=lambda sample: sample.kwargs.get("dtype") in (torch.int64, torch.int32),
-        reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
-    ),
     TorchLibOpInfo("log", core_ops.aten_log),
     TorchLibOpInfo("le", core_ops.aten_le),
     TorchLibOpInfo("lerp", core_ops.aten_lerp, tolerance={torch.float16: (2e-3, 2e-1)}),

From 0129641c86ead43e4fd8cb271368d321522ef785 Mon Sep 17 00:00:00 2001
From: ARAVINDHAN T
Date: Sun, 16 Nov 2025 17:31:28 -0800
Subject: [PATCH 3/7] Fix syntax error in ops_test_data.py

---
 tests/function_libs/torch_lib/ops_test_data.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 1a941175de..ab9e7a5793 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -766,7 +766,7 @@ def _where_input_wrangler(
         "linspace",
         core_ops.aten_linspace,
         tolerance={torch.float16: (2e-2, 2e-3)},
-    )
+    ),
     TorchLibOpInfo("log", core_ops.aten_log),
     TorchLibOpInfo("le", core_ops.aten_le),
     TorchLibOpInfo("lerp", core_ops.aten_lerp, tolerance={torch.float16: (2e-3, 2e-1)}),
From c772664f3a29306ec737acd8dd4e5872a2091ddf Mon Sep 17 00:00:00 2001
From: Aravind-11
Date: Thu, 11 Dec 2025 00:59:37 -0700
Subject: [PATCH 4/7] fixes

---
 .../function_libs/torch_lib/ops/core.py | 43 +++++++++++++------
 1 file changed, 31 insertions(+), 12 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 2ac3f56d7e..c836f7cb09 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -5085,40 +5085,59 @@ def aten_linear_backward(
 @torch_op("aten::linspace", trace_only=True)
 def aten_linspace(
-    start: TFloat,
-    end: TFloat,
+    start: float,
+    end: float,
     steps: int,
-    dtype: int = FLOAT.dtype,
+    dtype: int = -1,
     layout: str = "",
     device: str = "",
     pin_memory: bool = False,
 ) -> TensorType:
     """linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
+
     if dtype == -1 or dtype is None:
         dtype = FLOAT.dtype
-
+
     if steps == 0:
         return aten_full(op.Constant(value_ints=[0]), 0.0, dtype=dtype)
     if steps == 1:
         return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)
-
-    compute_dtype = FLOAT.dtype
-
+
+    # Use double precision for computation to match PyTorch's internal precision
+    compute_dtype = DOUBLE.dtype
+
+    # For integer output dtypes, cast start/end to the target dtype first
+    # This matches PyTorch's behavior where fractional start/end values
+    # are truncated before computing the linspace
+    is_integer_dtype = dtype not in (FLOAT.dtype, DOUBLE.dtype, FLOAT16.dtype, COMPLEX64.dtype, COMPLEX128.dtype)
+
+    if is_integer_dtype:
+        # Cast to integer dtype first, then to compute dtype
+        # This ensures truncation happens before computation
+        start_int = op.Cast(start, to=dtype)
+        end_int = op.Cast(end, to=dtype)
+        start_f = op.Cast(start_int, to=compute_dtype)
+        end_f = op.Cast(end_int, to=compute_dtype)
+    else:
+        # For float dtypes, cast directly to compute dtype
+        start_f = op.Cast(start, to=compute_dtype)
+        end_f = op.Cast(end, to=compute_dtype)
+
     rg = aten_arange_start(0, steps, dtype=compute_dtype)
-    start_f = op.Cast(start, to=compute_dtype)
-    end_f = op.Cast(end, to=compute_dtype)
     steps_f = op.Cast(steps, to=compute_dtype)
     one = op.Cast(1.0, to=compute_dtype)
     two = op.Cast(2.0, to=compute_dtype)
     steps_minus_1 = op.Sub(steps_f, one)
     step = op.Div(op.Sub(end_f, start_f), steps_minus_1)
-
+
+    # Two-sided computation for numerical stability at endpoints
+    # Use forward computation for first half, backward for second half
     lin_vals = op.Where(
         rg < op.Div(steps_f, two),
         op.Add(start_f, op.Mul(step, rg)),
-        op.Sub(end_f, op.Mul(step, op.Sub(op.Sub(steps_f, one), rg))),
+        op.Sub(end_f, op.Mul(step, op.Sub(steps_minus_1, rg))),
     )
-
+
     return op.Cast(lin_vals, to=dtype)
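Two things change here: start/end are truncated to the integer dtype before any arithmetic, and the interpolation itself moves from FLOAT to DOUBLE. The double-precision part matters once int64 endpoints exceed what a float32 mantissa can represent; the snippet below only illustrates that rounding and is not code from the patch:

    import numpy as np

    end = 2**40 + 1                # representable as int64, but not exactly as float32
    print(np.float32(end) == end)  # False: float32 rounds the endpoint away
    print(np.float64(end) == end)  # True: float64 still holds it exactly
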
From 1e37116c7525d7854b762a7a262d1ffe1e1e84fe Mon Sep 17 00:00:00 2001
From: Aravind-11
Date: Thu, 11 Dec 2025 11:53:47 -0700
Subject: [PATCH 5/7] linting

---
 .../function_libs/torch_lib/ops/core.py | 24 ++++++++++++-------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index c836f7cb09..0c6459ec53 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -5094,23 +5094,29 @@ def aten_linspace(
     pin_memory: bool = False,
 ) -> TensorType:
     """linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
-
+
     if dtype == -1 or dtype is None:
         dtype = FLOAT.dtype
-
+
     if steps == 0:
         return aten_full(op.Constant(value_ints=[0]), 0.0, dtype=dtype)
     if steps == 1:
         return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)
-
+
     # Use double precision for computation to match PyTorch's internal precision
     compute_dtype = DOUBLE.dtype
-
+
     # For integer output dtypes, cast start/end to the target dtype first
     # This matches PyTorch's behavior where fractional start/end values
     # are truncated before computing the linspace
-    is_integer_dtype = dtype not in (FLOAT.dtype, DOUBLE.dtype, FLOAT16.dtype, COMPLEX64.dtype, COMPLEX128.dtype)
-
+    is_integer_dtype = dtype not in (
+        FLOAT.dtype,
+        DOUBLE.dtype,
+        FLOAT16.dtype,
+        COMPLEX64.dtype,
+        COMPLEX128.dtype,
+    )
+
     if is_integer_dtype:
         # Cast to integer dtype first, then to compute dtype
         # This ensures truncation happens before computation
         start_int = op.Cast(start, to=dtype)
         end_int = op.Cast(end, to=dtype)
         start_f = op.Cast(start_int, to=compute_dtype)
         end_f = op.Cast(end_int, to=compute_dtype)
     else:
         # For float dtypes, cast directly to compute dtype
         start_f = op.Cast(start, to=compute_dtype)
         end_f = op.Cast(end, to=compute_dtype)
-
+
     rg = aten_arange_start(0, steps, dtype=compute_dtype)
     steps_f = op.Cast(steps, to=compute_dtype)
     one = op.Cast(1.0, to=compute_dtype)
     two = op.Cast(2.0, to=compute_dtype)
     steps_minus_1 = op.Sub(steps_f, one)
     step = op.Div(op.Sub(end_f, start_f), steps_minus_1)
-
+
     # Two-sided computation for numerical stability at endpoints
     # Use forward computation for first half, backward for second half
     lin_vals = op.Where(
         rg < op.Div(steps_f, two),
         op.Add(start_f, op.Mul(step, rg)),
         op.Sub(end_f, op.Mul(step, op.Sub(steps_minus_1, rg))),
     )
-
+
     return op.Cast(lin_vals, to=dtype)
From 52a9172582355852bd23ba312814e1377b2f1674 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Mon, 15 Dec 2025 13:00:51 -0800
Subject: [PATCH 6/7] Refactor casting logic for arange function

---
 .../function_libs/torch_lib/ops/core.py | 20 +++++++++----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 0c6459ec53..1cca05219a 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -5117,24 +5117,22 @@ def aten_linspace(
         COMPLEX128.dtype,
     )
 
-    if is_integer_dtype:
+    if ir.DataType(dtype).is_integer():
         # Cast to integer dtype first, then to compute dtype
         # This ensures truncation happens before computation
         start_int = op.Cast(start, to=dtype)
         end_int = op.Cast(end, to=dtype)
-        start_f = op.Cast(start_int, to=compute_dtype)
-        end_f = op.Cast(end_int, to=compute_dtype)
+        start = op.Cast(start_int, to=compute_dtype)
+        end = op.Cast(end_int, to=compute_dtype)
     else:
-        # For float dtypes, cast directly to compute dtype
-        start_f = op.Cast(start, to=compute_dtype)
-        end_f = op.Cast(end, to=compute_dtype)
+        compute_dtype = dtype
 
     rg = aten_arange_start(0, steps, dtype=compute_dtype)
-    steps_f = op.Cast(steps, to=compute_dtype)
-    one = op.Cast(1.0, to=compute_dtype)
-    two = op.Cast(2.0, to=compute_dtype)
-    steps_minus_1 = op.Sub(steps_f, one)
-    step = op.Div(op.Sub(end_f, start_f), steps_minus_1)
+    steps_f = op.Constant(value=ir.tensor(steps, dtype=compute_dtype))
+    one = op.Constant(value=ir.tensor(1, dtype=compute_dtype))
+    two = op.Constant(value=ir.tensor(2, dtype=compute_dtype))
+    steps_minus_1 = op.Constant(value=ir.tensor(steps - 1, dtype=compute_dtype))
+    step = op.Constant(value=ir.tensor((end - start) / (steps - 1), dtype=compute_dtype))
 
     # Two-sided computation for numerical stability at endpoints
     # Use forward computation for first half, backward for second half
     lin_vals = op.Where(
         rg < op.Div(steps_f, two),
         op.Add(start_f, op.Mul(step, rg)),
         op.Sub(end_f, op.Mul(step, op.Sub(steps_minus_1, rg))),
     )

From beacf269afdf12503654040c678581a632c83b6e Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Mon, 15 Dec 2025 13:05:27 -0800
Subject: [PATCH 7/7] Update core.py

---
 .../function_libs/torch_lib/ops/core.py | 21 ++++++-------------
 1 file changed, 6 insertions(+), 15 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 1cca05219a..594c9d9465 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -5103,29 +5103,20 @@ def aten_linspace(
     if steps == 1:
         return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)
 
-    # Use double precision for computation to match PyTorch's internal precision
-    compute_dtype = DOUBLE.dtype
-
     # For integer output dtypes, cast start/end to the target dtype first
     # This matches PyTorch's behavior where fractional start/end values
     # are truncated before computing the linspace
-    is_integer_dtype = dtype not in (
-        FLOAT.dtype,
-        DOUBLE.dtype,
-        FLOAT16.dtype,
-        COMPLEX64.dtype,
-        COMPLEX128.dtype,
-    )
-
     if ir.DataType(dtype).is_integer():
+        # Use double precision for computation to match PyTorch's internal precision
+        compute_dtype = ir.DataType.DOUBLE
         # Cast to integer dtype first, then to compute dtype
         # This ensures truncation happens before computation
-        start_int = op.Cast(start, to=dtype)
-        end_int = op.Cast(end, to=dtype)
-        start = op.Cast(start_int, to=compute_dtype)
-        end = op.Cast(end_int, to=compute_dtype)
+        start_f = op.Constant(value=ir.tensor(int(start), dtype=compute_dtype))
+        end_f = op.Constant(value=ir.tensor(int(end), dtype=compute_dtype))
     else:
         compute_dtype = dtype
+        start_f = op.Constant(value=ir.tensor(start, dtype=compute_dtype))
+        end_f = op.Constant(value=ir.tensor(end, dtype=compute_dtype))
 
     rg = aten_arange_start(0, steps, dtype=compute_dtype)
     steps_f = op.Constant(value=ir.tensor(steps, dtype=compute_dtype))
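Because start, end and steps are plain Python scalars at trace time, the final version folds steps_f, one, two, steps_minus_1 and step into Constant nodes of the compute dtype instead of emitting Cast/Sub/Div nodes, and truncates integer endpoints with int() before building the constants. A quick local sanity check against torch.linspace, following the truncate-then-interpolate recipe described above (reference is a throwaway helper, not part of the PR), could look like:

    import torch

    def reference(start, end, steps, dtype=torch.int64):
        # Truncate the endpoints first (integer dtype), interpolate in float64, cast back.
        start, end = int(start), int(end)
        rg = torch.arange(steps, dtype=torch.float64)
        step = (end - start) / (steps - 1)
        vals = torch.where(rg < steps / 2, start + step * rg, end - step * (steps - 1 - rg))
        return vals.to(dtype)

    print(torch.linspace(0, 10, 5, dtype=torch.int64))  # expected: tensor([ 0,  2,  5,  7, 10])
    print(reference(0, 10, 5))                           # should agree with the line above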