49 changes: 33 additions & 16 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -5085,10 +5085,10 @@ def aten_linear_backward(
 
 @torch_op("aten::linspace", trace_only=True)
 def aten_linspace(
-    start: TFloat,
-    end: TFloat,
+    start: float,
+    end: float,
     steps: int,
-    dtype: int = FLOAT.dtype,
+    dtype: int = -1,
     layout: str = "",
     device: str = "",
     pin_memory: bool = False,
@@ -5098,26 +5098,43 @@ def aten_linspace(
     if dtype == -1 or dtype is None:
         dtype = FLOAT.dtype
 
     # Reference: https://github.com/pytorch/pytorch/blob/b35ca2cb941b5ba90858322810ca85c31e4541fd/torch/_refs/__init__.py#L4896
     if steps == 0:
         return aten_full(op.Constant(value_ints=[0]), 0.0, dtype=dtype)
     if steps == 1:
         return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)
 
-    rg = aten_arange_start(0, steps, dtype=dtype)
-    start = op.Cast(start, to=dtype)
-    end = op.Cast(end, to=dtype)
-    steps_float = op.Cast(steps, to=dtype)
-    one = op.Cast(1.0, to=dtype)
-    two = op.Cast(2.0, to=dtype)
-    steps_minus_1 = op.Cast(steps - 1, to=dtype)
-    step = op.Div(op.Sub(end, start), steps_minus_1)
-    return op.Where(
-        rg < op.Div(steps_float, two),
-        start + step * rg,
-        end - step * (steps_float - one - rg),
-    )
+    # For integer output dtypes, truncate start/end to the target dtype first.
+    # This matches PyTorch's behavior, where fractional start/end values are
+    # truncated before the linspace is computed.
+    if ir.DataType(dtype).is_integer():
+        # Compute in double precision to match PyTorch's internal precision.
+        compute_dtype = ir.DataType.DOUBLE
+        # Truncate to the integer dtype first, then represent the result in the
+        # compute dtype, so that truncation happens before the computation.
+        start_f = op.Constant(value=ir.tensor(int(start), dtype=compute_dtype))
+        end_f = op.Constant(value=ir.tensor(int(end), dtype=compute_dtype))
+    else:
+        compute_dtype = dtype
+        start_f = op.Constant(value=ir.tensor(start, dtype=compute_dtype))
+        end_f = op.Constant(value=ir.tensor(end, dtype=compute_dtype))
+
+    rg = aten_arange_start(0, steps, dtype=compute_dtype)
+    steps_f = op.Constant(value=ir.tensor(steps, dtype=compute_dtype))
+    one = op.Constant(value=ir.tensor(1, dtype=compute_dtype))
+    two = op.Constant(value=ir.tensor(2, dtype=compute_dtype))
+    steps_minus_1 = op.Constant(value=ir.tensor(steps - 1, dtype=compute_dtype))
+    step = op.Constant(value=ir.tensor((end - start) / (steps - 1), dtype=compute_dtype))
+
+    # Two-sided computation for numerical stability at the endpoints:
+    # forward from start for the first half, backward from end for the second.
+    lin_vals = op.Where(
+        rg < op.Div(steps_f, two),
+        op.Add(start_f, op.Mul(step, rg)),
+        op.Sub(end_f, op.Mul(step, op.Sub(steps_minus_1, rg))),
+    )
+
+    return op.Cast(lin_vals, to=dtype)
 
 
 @torch_op("aten::log", trace_only=True)
 def aten_log(self: TFloat) -> TFloat:
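As a reference for the numerics above, here is a standalone NumPy sketch (illustrative only, not part of the change) of the same algorithm: truncate-then-compute for integer output dtypes, plus the two-sided evaluation that keeps both endpoints exact. The helper name `simulate_linspace` is ours, not part of the library.

```python
import numpy as np


def simulate_linspace(start, end, steps, dtype):
    """NumPy model of the ONNX graph built by aten_linspace above (a sketch)."""
    if steps == 0:
        return np.empty(0, dtype=dtype)
    if steps == 1:
        return np.full(1, start, dtype=dtype)

    if np.issubdtype(dtype, np.integer):
        # Truncate start/end before computing, then work in double precision,
        # mirroring the integer-dtype branch of the PR.
        compute_dtype = np.float64
        start_f, end_f = float(int(start)), float(int(end))
    else:
        compute_dtype = dtype
        start_f, end_f = start, end

    rg = np.arange(steps, dtype=compute_dtype)
    step = (end - start) / (steps - 1)  # step uses the original, untruncated endpoints
    # Two-sided evaluation: forward from start for the first half of the
    # indices, backward from end for the second half.
    vals = np.where(
        rg < steps / 2,
        start_f + step * rg,
        end_f - step * (steps - 1 - rg),
    )
    return vals.astype(dtype)


print(simulate_linspace(0.0, 10.0, 5, np.float32))  # [ 0.   2.5  5.   7.5 10. ]
print(simulate_linspace(0.5, 4.5, 5, np.int64))     # [0 1 2 3 4]
```

The backward half is the reason both endpoints come out exact: the last element is computed as `end - step * 0` rather than as `start + step * (steps - 1)`, which can drift by accumulated rounding.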
8 changes: 0 additions & 8 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -777,14 +777,6 @@ def _where_input_wrangler(
         "linspace",
         core_ops.aten_linspace,
         tolerance={torch.float16: (2e-2, 2e-3)},
-    )
-    .xfail(
-        dtypes=(torch.int64, torch.int32),
-        reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
-    )
-    .skip(
-        matcher=lambda sample: sample.kwargs.get("dtype") in (torch.int64, torch.int32),
-        reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
     ),
     TorchLibOpInfo("log", core_ops.aten_log),
     TorchLibOpInfo("le", core_ops.aten_le),