diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index b287cec057..594c9d9465 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5085,10 +5085,10 @@ def aten_linear_backward( @torch_op("aten::linspace", trace_only=True) def aten_linspace( - start: TFloat, - end: TFloat, + start: float, + end: float, steps: int, - dtype: int = FLOAT.dtype, + dtype: int = -1, layout: str = "", device: str = "", pin_memory: bool = False, @@ -5098,26 +5098,43 @@ def aten_linspace( if dtype == -1 or dtype is None: dtype = FLOAT.dtype - # Reference: https://github.com/pytorch/pytorch/blob/b35ca2cb941b5ba90858322810ca85c31e4541fd/torch/_refs/__init__.py#L4896 if steps == 0: return aten_full(op.Constant(value_ints=[0]), 0.0, dtype=dtype) if steps == 1: return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype) - rg = aten_arange_start(0, steps, dtype=dtype) - start = op.Cast(start, to=dtype) - end = op.Cast(end, to=dtype) - steps_float = op.Cast(steps, to=dtype) - one = op.Cast(1.0, to=dtype) - two = op.Cast(2.0, to=dtype) - steps_minus_1 = op.Cast(steps - 1, to=dtype) - step = op.Div(op.Sub(end, start), steps_minus_1) - return op.Where( - rg < op.Div(steps_float, two), - start + step * rg, - end - step * (steps_float - one - rg), + # For integer output dtypes, cast start/end to the target dtype first + # This matches PyTorch's behavior where fractional start/end values + # are truncated before computing the linspace + if ir.DataType(dtype).is_integer(): + # Use double precision for computation to match PyTorch's internal precision + compute_dtype = ir.DataType.DOUBLE + # Cast to integer dtype first, then to compute dtype + # This ensures truncation happens before computation + start_f = op.Constant(value=ir.tensor(int(start), dtype=compute_dtype)) + end_f = op.Constant(value=ir.tensor(int(end), dtype=compute_dtype)) + else: + compute_dtype = dtype + start_f = op.Constant(value=ir.tensor(start, dtype=compute_dtype)) + end_f = op.Constant(value=ir.tensor(end, dtype=compute_dtype)) + + rg = aten_arange_start(0, steps, dtype=compute_dtype) + steps_f = op.Constant(value=ir.tensor(steps, dtype=compute_dtype)) + one = op.Constant(value=ir.tensor(1, dtype=compute_dtype)) + two = op.Constant(value=ir.tensor(2, dtype=compute_dtype)) + steps_minus_1 = op.Constant(value=ir.tensor(steps - 1, dtype=compute_dtype)) + step = op.Constant(value=ir.tensor((end - start) / (steps - 1), dtype=compute_dtype)) + + # Two-sided computation for numerical stability at endpoints + # Use forward computation for first half, backward for second half + lin_vals = op.Where( + rg < op.Div(steps_f, two), + op.Add(start_f, op.Mul(step, rg)), + op.Sub(end_f, op.Mul(step, op.Sub(steps_minus_1, rg))), ) + return op.Cast(lin_vals, to=dtype) + @torch_op("aten::log", trace_only=True) def aten_log(self: TFloat) -> TFloat: diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index e87a0cc232..93c2d1045d 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -777,14 +777,6 @@ def _where_input_wrangler( "linspace", core_ops.aten_linspace, tolerance={torch.float16: (2e-2, 2e-3)}, - ) - .xfail( - dtypes=(torch.int64, torch.int32), - reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", - ) - .skip( - matcher=lambda sample: sample.kwargs.get("dtype") in (torch.int64, torch.int32), - reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", ), TorchLibOpInfo("log", core_ops.aten_log), TorchLibOpInfo("le", core_ops.aten_le),