Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 55 additions & 30 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2861,6 +2861,7 @@ def aten_embedding_bag(
sparse: bool = False,
per_sample_weights: Optional[TFloat] = None,
include_last_offset: bool = False,
padding_idx: Optional[int] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked https://github.com/pytorch/pytorch/blob/8656dea039bd0a31952a6a9792566e70b07429dc/aten/src/ATen/native/native_functions.yaml#L2372-L2378 and found that embedding_bag does not have padding_idx as a parameter. I think you only need to update aten_embedding_bag_padding_idx

) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
"""embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)"""

Expand Down Expand Up @@ -2957,23 +2958,24 @@ def _aten_embedding_bag_onnx(

# Only compute the shape of other 3 outputs, we don't care the value
if mode == 0: # sum
offset2bag = op.Shape(indices, start=0, end=0) # Generate empty tensor
offset2bag = op.Cast(op.Shape(indices, start=0, end=0), to=INT64.dtype)
if op.Equal(include_last_offset, True):
bag_size = op.Expand(0, op.Shape(offsets))
bag_size = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
max_indices = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
else:
bag_size = op.Expand(0, op.Shape(offsets) - 1)
max_indices = op.Expand(0, op.Shape(bag_size))
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
elif mode == 1: # mean
offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
bag_size = op.Expand(0, op.Shape(offsets) - 1)
max_indices = op.Expand(0, op.Shape(bag_size))
offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
else: # max
offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
bag_size = op.Expand(0, op.Shape(offsets) - 1)
offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
# shape = (bag_size.dim[0], weight.dim[1])
dim_0 = op.Shape(bag_size, start=0, end=1)
dim_1 = op.Shape(weight, start=1, end=2)
max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0))
max_indices = op.Cast(op.Expand(0, op.Concat(dim_0, dim_1, axis=0)), to=INT64.dtype)

return result, offset2bag, bag_size, max_indices

Expand All @@ -2995,27 +2997,40 @@ def aten_embedding_bag_padding_idx(
sparse: bool = False,
per_sample_weights: Optional[TFloat] = None,
include_last_offset: bool = False,
padding_idx: int = -1,
padding_idx: Optional[int] = None,
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
"""embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)

We add default values for the attributes to accommodate _embedding_bag as well:
_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1)
"""
assert padding_idx is not None, (
"padding_idx must not be None. This is likely a dispatcher error"
)

if per_sample_weights is None:
per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices))
per_sample_weights = op.CastLike(per_sample_weights, weight)

# Change padding_idx to positive value, -1 means the last index
if padding_idx < 0:
padding_idx = weight.shape[0] + padding_idx
if padding_idx is not None:
# Call the existing function for handling padding_idx
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
weight,
indices,
offsets,
mode,
per_sample_weights,
include_last_offset,
padding_idx,
)

return result, offset2bag, bag_size, max_indices

result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx
# When padding_idx is None, use the standard embedding_bag implementation
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx(
weight,
indices,
offsets,
mode,
per_sample_weights,
include_last_offset,
)

return result, offset2bag, bag_size, max_indices
Expand All @@ -3032,6 +3047,12 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
padding_idx: int,
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
neg_1 = op.Constant(value_ints=[-1])

num_embeddings = op.Shape(weight, start=0, end=1) # Get number of rows in weight
num_embeddings_scalar = op.Squeeze(num_embeddings)
if padding_idx < 0:
padding_idx = padding_idx + num_embeddings_scalar

# Get weight out according to indices,
# e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
indices_weight = op.Gather(weight, indices)
Expand Down Expand Up @@ -3067,7 +3088,10 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
cond_2 = j < end_pos
while cond_2:
index = op.Gather(indices, j)
if not op.Equal(index, padding_idx):
normalized_index = index
if index < 0:
normalized_index = index + num_embeddings_scalar
if not op.Equal(normalized_index, padding_idx):
# Something like the 'append' operation
curr_offsets = op.Concat(curr_offsets, op.Reshape(j, neg_1), axis=0)
j = j + 1
Expand Down Expand Up @@ -3096,23 +3120,24 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
result = op.CastLike(result, weight)

if mode == 0: # sum
offset2bag = op.Expand(0, op.Shape(indices))
offset2bag = op.Cast(op.Expand(0, op.Shape(indices)), to=INT64.dtype)
if op.Equal(include_last_offset, True):
bag_size = op.Expand(0, op.Shape(offsets))
bag_size = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
max_indices = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
else:
bag_size = op.Expand(0, op.Shape(offsets) - 1)
max_indices = op.Expand(0, op.Shape(bag_size))
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
elif mode == 1: # mean
offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
bag_size = op.Expand(0, op.Shape(offsets) - 1)
max_indices = op.Expand(0, op.Shape(bag_size))
offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
else: # mode == 2, max
offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
bag_size = op.Expand(0, op.Shape(offsets) - 1)
offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
# shape = (bag_size.dim[0], weight.dim[1])
dim_0 = op.Shape(bag_size, start=0, end=1)
dim_1 = op.Shape(weight, start=1, end=2)
max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0))
max_indices = op.Cast(op.Expand(0, op.Concat(dim_0, dim_1, axis=0)), to=INT64.dtype)

return result, offset2bag, bag_size, max_indices

Expand Down
26 changes: 25 additions & 1 deletion tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,25 @@ def xfail(
# Modify this section ##########################################################


def _embedding_bag_input_wrangler(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Final round of reivews: Is this necessary? I think it should accept a None input?

args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# ONNX attributes cannot be None; omit padding_idx if it's None.
if "padding_idx" in kwargs:
padding_idx = kwargs.pop("padding_idx")
if padding_idx is not None:
kwargs["padding_idx"] = int(padding_idx)

# Ensure indices/offsets are int64 (positional: weight, indices, offsets, ...)
if len(args) >= 3:
if isinstance(args[1], torch.Tensor):
args[1] = args[1].to(torch.long)
if isinstance(args[2], torch.Tensor):
args[2] = args[2].to(torch.long)

return args, kwargs


def _amin_amax_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
Expand Down Expand Up @@ -908,12 +927,17 @@ def _where_input_wrangler(
core_ops.aten_embedding_bag,
tolerance={torch.float32: (1e-4, 5e-4)},
compare_shape_only_for_output=(1, 2, 3),
).skip(dtypes=(torch.float16,), reason="fixme: results mismatch in torch nightly."),
input_wrangler=_embedding_bag_input_wrangler,
).skip(
dtypes=(torch.float16,),
reason="fixme: results mismatch in torch nightly.",
),
TorchLibOpInfo(
"ops.aten.embedding_bag.padding_idx",
core_ops.aten_embedding_bag_padding_idx,
tolerance={torch.float16: (1e-2, 1e-2)},
compare_shape_only_for_output=(1, 2, 3),
input_wrangler=_embedding_bag_input_wrangler,
),
TorchLibOpInfo(
"ops.aten.embedding_renorm",
Expand Down
Loading