Skip to content

Conversation

Copy link

Copilot AI commented Dec 5, 2025

Summary

bm.for_loop accepted a jit parameter but never used it. Passing jit=False had no effect—code was always JIT-compiled.

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation update
  • Code refactoring (no functional changes)
  • Performance improvement
  • Test coverage improvement
  • Other (please describe):

Description

What does this PR do?

Implements the jit parameter in bm.for_loop and removes unused parameters remat and unroll_kwargs.

Why is this change needed?

The jit parameter was documented but non-functional. Users debugging code with jit=False still got JIT-compiled traces instead of eager execution.

How does it work?

When jit=False, wraps the brainstate.transform.for_loop call in jax.disable_jit() context manager. Uses identity check (is False) to distinguish explicit False from None (default) or other falsy values.

Changes Made

  • Added jax import for jax.disable_jit() access
  • Implemented conditional JIT control: wraps call in jax.disable_jit() when jit is False
  • Removed unused parameters remat and unroll_kwargs from function signature and documentation
  • Improved jit parameter documentation to clarify usage

Testing

How has this been tested?

import brainpy.math as bm

a = bm.Variable(bm.zeros(1))

def body(x):
    print(f"x = {x}")  # Shows tracers with JIT, actual values without
    a.value += x
    return a.value

# JIT disabled - prints actual values: 0, 1, 2
bm.for_loop(body, operands=bm.arange(3), jit=False)

# JIT enabled (default) - prints: JitTracer<int32[]>
bm.for_loop(body, operands=bm.arange(3))

Test environment:

  • Python version: 3.12.3
  • JAX version: 0.8.1
  • OS: Linux

Test coverage:

  • Added new tests for this change
  • All existing tests pass
  • Manual testing performed

Performance Impact

  • No performance impact
  • Performance improved (provide benchmarks)
  • Performance may be affected (explain why and how much)

Documentation

  • Docstrings added/updated
  • Documentation updated (if needed)
  • Examples added/updated (if applicable)
  • CHANGELOG entry added (for notable changes)

Screenshots (if applicable)

N/A

Breaking Changes

None. Removal of unused parameters remat and unroll_kwargs is non-breaking—they were never implemented.

Checklist

  • My code follows the project's style guidelines
  • I have performed a self-review of my code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • Any dependent changes have been merged and published
  • I have checked for duplicate PRs/issues
  • I have read and followed the CONTRIBUTING guidelines

Additional Notes

CodeQL security scan: 0 alerts.

Related Issues/PRs

N/A

Original prompt

This section details on the original issue you should resolve

<issue_title>[BUG] bm.for_loop receive jit param but don't handle it</issue_title>
<issue_description>For brainpy.math.for_loop, it receive jit param and the docstring said it is described as "Whether to just-in-time compile the function."

def for_loop(
body_fun: Callable,
operands: Any,
reverse: bool = False,
unroll: int = 1,
remat: bool = False,
jit: Optional[bool] = None,
progress_bar: bool = False,
unroll_kwargs: Optional[Dict] = None,
):

But never handle it in main struct of this func

if not isinstance(operands, (tuple, list)):
operands = (operands,)
return brainstate.transform.for_loop(
warp_to_no_state_input_output(body_fun),
*operands, reverse=reverse, unroll=unroll,
pbar=brainstate.transform.ProgressBar() if progress_bar else None,
)

Perhaps with jax.disable_jit() is a solution.


And also there are some unused params like remat and unroll_kwargs, need to be removed in future.</issue_description>

Comments on the Issue (you are @copilot in this section)


💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more Copilot coding agent tips in the docs.

Co-authored-by: Routhleck <88108241+Routhleck@users.noreply.github.com>
Copilot AI changed the title [WIP] Fix jit parameter handling in bm.for_loop Fix bm.for_loop jit parameter handling and remove unused parameters Dec 5, 2025
Copilot AI requested a review from Routhleck December 5, 2025 05:38
@Routhleck Routhleck marked this pull request as ready for review December 5, 2025 05:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] bm.for_loop receive jit param but don't handle it

2 participants