
Conversation

@studyingeugene
Contributor

Summary

This PR refactors the additive noise generation logic in quantize() to improve compatibility with torch.compile / TorchDynamo.

The previous implementation relied on an in-place uniform_() random operation, which can cause graph breaks or compilation issues under torch.compile.
This PR replaces it with a rand_like()-based formulation that is compile-friendly while preserving identical statistical behavior.

Observed Issues

In several environments, torch.compile fails when encountering in-place random initialization inside quantize().

assert_size_stride(buf509, (16, 32, 16, 16), (8192, 256, 16, 1), 'torch.ops.aten.uniform.default')
AssertionError: expected size 32==32, stride 1==256 at dim=1; expected size 16==16, stride 512==16 at dim=2; expected size 16==16, stride 32==1 at dim=3
Error in op: torch.ops.aten.uniform.default
This error most often comes from a incorrect fake (aka meta) kernel for a custom op.
Use torch.library.opcheck to test your custom op.

I observed torch.compile(fullgraph=True) failures in some environments originating from torch.ops.aten.uniform.default. TorchInductor asserts an expected size/stride for the temporary buffer at runtime, but the actual stride can differ (e.g., due to layout differences such as channels-last), causing assert_size_stride to fail.

rand_like() avoids the issue because TorchInductor lowers it as an out-of-place value-producing operation, whereas aten.uniform is lowered with static stride assumptions that can conflict with runtime layouts.
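As a sanity check (not a reliable reproduction of the failure), one way to compare how TorchDynamo captures the two formulations is torch._dynamo.explain, assuming a recent PyTorch (2.1+) where that API is available. The functions below are illustrative placeholders, not CompressAI code:

import torch
import torch._dynamo

# Hypothetical standalone functions mirroring the two noise formulations.
def noise_inplace(x, half=0.5):
    return x + torch.empty_like(x).uniform_(-half, half)

def noise_functional(x, half=0.5):
    return x + (2 * half * torch.rand_like(x) - half)

x = torch.rand(16, 32, 16, 16)

# explain() traces the function with Dynamo (no Inductor compilation) and reports
# how many graph breaks occurred; results can vary across environments.
print(torch._dynamo.explain(noise_inplace)(x).graph_break_count)
print(torch._dynamo.explain(noise_functional)(x).graph_break_count)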

What Changed

# before 
noise = torch.empty_like(inputs).uniform_(-half, half)

# after
noise = 2 * half * torch.rand_like(inputs) - half
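As a quick empirical check (a standalone sketch, with placeholder values for inputs and half), the two expressions draw from the same distribution:

import torch

inputs = torch.randn(1, 128, 32, 32)  # placeholder tensor
half = 0.5                            # placeholder quantization half-step

noise_old = torch.empty_like(inputs).uniform_(-half, half)  # before
noise_new = 2 * half * torch.rand_like(inputs) - half       # after

# Both are uniform on [-half, half]: mean close to 0, variance close to half**2 / 3.
for name, noise in [("uniform_", noise_old), ("rand_like", noise_new)]:
    print(name, noise.min().item(), noise.max().item(),
          noise.mean().item(), noise.var().item())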

Reproducing the Error

Unfortunately, I’m not able to provide a reliable minimal reproduction for this issue.

  • The failure appears to be environment-dependent / non-deterministic: I observed it on several GPUs (A1000, A2000, RTX 3090), but it does not reproduce consistently across runs or across machines.
  • I also attempted to extract a minimal standalone snippet that triggers the same TorchInductor aten.uniform stride assertion, but was not successful.

I apologize for the lack of a reproducible test case.

Compatibility

This change preserves:

  • identical noise range [-half, half]
  • identical expected value and variance (see the short derivation below)
  • no functional or numerical behavior change in the quantization logic
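For reference, this follows from the standard moments of the uniform distribution: both expressions draw U ~ Uniform(-half, half), so E[U] = 0 and Var[U] = (2 * half)^2 / 12 = half^2 / 3. The rand_like() form is simply the affine map 2 * half * U0 - half applied to U0 ~ Uniform[0, 1), which yields the same range and moments.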

Thanks for reading

I appreciate your time reviewing this change.

…h.compile

- Replace `empty_like(...).uniform_()` with `rand_like()`-based expression
- Avoid in-place random ops that are problematic for torch.compile / dynamo
- Preserve identical noise range [-half, half] and statistical behavior
@studyingeugene
Contributor Author

studyingeugene commented Dec 16, 2025

Note on failing test_compiling

This PR currently does not pass test_compiling in test_entropy_models.py.
Below is an explanation of the observed behavior and a request for maintainer guidance on the preferred direction.

  def test_compiling(self):
      entropy_bottleneck = EntropyBottleneck(128)
      x0 = torch.rand(1, 128, 32, 32)
      x1 = x0.clone()
      x0.requires_grad_(True)
      x1.requires_grad_(True)

      torch.manual_seed(32)
      y0 = entropy_bottleneck(x0)

      m = torch.compile(entropy_bottleneck)

      torch.manual_seed(32)
      y1 = m(x1)

      assert torch.allclose(y0[0], y1[0])
      assert torch.allclose(y0[1], y1[1])

      y0[0].sum().backward()
      y1[0].sum().backward()

      assert torch.allclose(x0.grad, x1.grad)

Observed root cause

The difference stems from how random number generation (RNG) is handled under torch.compile.

  • The previous implementation based on uniform_ is not reliably supported under torch.compile(fullgraph=True).
    In my environments, it either causes compilation failures (e.g., aten.uniform stride assertions) or triggers graph breaks, falling back to vanilla execution for that portion of the graph.

  • As a consequence, in the current test setup (torch.compile(fullgraph=False)), both the vanilla execution and the compiled execution effectively run the RNG path in vanilla mode, and therefore, when the same seed is used, identical random values are often produced.

  • The new implementation based on rand_like, on the other hand, is more readily captured into the compiled graph in my setup, enabling torch.compile(fullgraph=True).

  • Once the RNG operation becomes part of the compiled graph, however, the execution order and RNG consumption can differ between vanilla and compiled execution.

  • As a result, even with the same seed, the compiled execution may produce different random values compared to the vanilla execution.

In short:

  • uniform_: often not reliably captured → fullgraph compilation failure → vanilla/compiled tend to share the same RNG behavior
  • rand_like: more likely captured → fullgraph compilation succeeds → vanilla/compiled RNG behavior may diverge (see the sketch below)
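A minimal sketch of the divergence described above (this is not the CompressAI test; add_noise is a placeholder standing in for the noise path in quantize()):

import torch

def add_noise(x, half=0.5):
    return x + (2 * half * torch.rand_like(x) - half)

x = torch.rand(1, 128, 32, 32)

torch.manual_seed(32)
y_eager = add_noise(x)

torch.manual_seed(32)
y_compiled = torch.compile(add_noise)(x)

# Once rand_like() is captured into the compiled graph, Inductor may use its own
# functional RNG path instead of replaying the eager ATen RNG stream, so the same
# seed does not guarantee bitwise-identical values.
print(torch.equal(y_eager, y_compiled))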

Trade-off introduced by this PR

Because of this behavior, the change introduces the following trade-off:

Pros

  • Enables torch.compile(fullgraph=True) in setups where aten.uniform was problematic
  • Allows faster training by enabling full-graph compilation
  • Removes a code path associated with compilation failures in some environments
  • This change only affects training-time behavior; inference results remain identical to the previous implementation

Cons

  • Exact RNG equivalence between vanilla and compiled execution is no longer guaranteed
  • While the statistical properties of the noise distribution are preserved, bitwise equality of random values between compiled and non-compiled runs is lost

Request for maintainer guidance

Given this trade-off, I would appreciate guidance on which aspect should be prioritized:

  • Supporting full-graph compilation (torch.compile(fullgraph=True)), or
  • Preserving strict RNG equivalence between vanilla and compiled execution, as currently enforced by test_compiling

I am happy to hear the preferred direction.
