diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py
index ae5c9901b..080af9c2f 100644
--- a/onnxscript/optimizer/_constant_folding_test.py
+++ b/onnxscript/optimizer/_constant_folding_test.py
@@ -741,6 +741,31 @@ def test_constant_folding_creates_constant_nodes_in_function(self):
         constant_nodes = [n for n in func.graph if n.op_type == "Constant"]
         self.assertEqual(len(constant_nodes), 1)
 
+    def test_initializer_as_graph_output_is_not_removed(self):
+        """Test that an initializer that is a graph output is not removed during constant folding."""
+        model = """
+
+            agraph (float[N] x) => (float[N] y, float z) {
+                constant = Constant ()
+                y = Mul(x, constant)
+                z = Identity(constant)
+            }
+        """
+
+        optimized = self._fold(model)
+        # After constant folding, the Identity node should be folded, and 'constant'
+        # should become an initializer with the output name 'z'.
+        # The key thing is that this initializer should NOT be removed even though
+        # the Identity node was folded, because it is a graph output.
+        self.assertIn("z", optimized.graph.initializers)
+        # The Identity node should be removed
+        identity_nodes = [n for n in optimized.graph if n.op_type == "Identity"]
+        self.assertEqual(len(identity_nodes), 0)
+        # Verify the graph still has both outputs
+        output_names = [o.name for o in optimized.graph.outputs]
+        self.assertIn("y", output_names)
+        self.assertIn("z", output_names)
+
 
 if __name__ == "__main__":
     unittest.main()