
Pabce/chipmunc

CHIPMUNC

Continuous Hauser-Feshbach Iterative Path Moving for Unenergetic Neutrino-Nucleus Cross-sections

CHIPMUNC is a small research codebase exploring differentiable Monte Carlo for nuclear de-excitation cascades. It simulates gamma-emission paths that transition from a continuum of states to a discrete ladder of levels, and computes path-wise gradients with respect to model parameters using JAX. The goal is to enable likelihood-based learning and calibration of phenomenological ingredients (e.g., level densities and transition strengths) with gradient-based methods.
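The pathwise-gradient idea can be shown without JAX on a toy problem. The sketch below is a minimal illustration. It assumes an exponential distribution whose inverse CDF, x = −ln(1 − u)/θ, is known in closed form. All function names here are hypothetical and are not part of CHIPMUNC. Holding the uniforms u fixed (common random numbers) makes every sample a differentiable function of θ. The per-sample derivatives can then be averaged into the gradient of a loss and fed to plain gradient descent.

```python
import math
import random

# Hypothetical toy model (not CHIPMUNC physics): x ~ Exponential(rate=theta),
# sampled by the inverse CDF x = -ln(1 - u) / theta.
def sample_exponential(theta, uniforms):
    return [-math.log(1.0 - u) / theta for u in uniforms]


def pathwise_grad(theta, uniforms, target):
    """Loss (mean(x) - target)**2 and its pathwise derivative w.r.t. theta.

    With u held fixed, each sample obeys dx/dtheta = -x / theta.
    """
    xs = sample_exponential(theta, uniforms)
    mean_x = sum(xs) / len(xs)
    dmean_dtheta = sum(-x / theta for x in xs) / len(xs)
    loss = (mean_x - target) ** 2
    grad = 2.0 * (mean_x - target) * dmean_dtheta
    return loss, grad


def calibrate(theta0, target, n_samples=2000, lr=1.0, steps=200, seed=0):
    """Gradient descent on theta using one fixed set of uniforms (common random numbers)."""
    rng = random.Random(seed)
    uniforms = [rng.random() for _ in range(n_samples)]
    theta = theta0
    for _ in range(steps):
        _, grad = pathwise_grad(theta, uniforms, target)
        theta -= lr * grad
    return theta, uniforms
```

Calling `calibrate(1.0, 0.5)` drives the sample mean of x toward the target 0.5, so θ ends up near 2. In the JAX code, the per-sample derivatives come from automatic and implicit differentiation rather than a hand-written formula.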

The implementation combines:

  • A continuum sampler via inverse-CDF root-finding (with jaxopt), including the probability of leaving the continuum.
  • A discrete branching process for the level scheme once the cascade enters the discrete region.
  • Differentiable path operations and implicit differentiation of the inverse-CDF to obtain gradients w.r.t. parameters.
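The inverse-CDF step of the continuum sampler can be sketched in plain Python, without JAX or jaxopt. This is a minimal sketch under several assumptions:

- The density ∝ E + θE² on [0, E_i] is a hypothetical stand-in for the level-density × transition-strength model.
- A hand-written bisection loop replaces jaxopt's solver.
- The probability of leaving the continuum is taken to be the CDF mass below a hypothetical threshold `e_crit`.

```python
import random

# Hypothetical toy CDF of the final energy on [0, e_init] for a density
# proportional to e + theta * e**2 (a stand-in for rho(E_f) * T(E_i - E_f)).
def toy_cdf(e, e_init, theta):
    def antiderivative(x):
        return 0.5 * x**2 + theta * x**3 / 3.0
    return antiderivative(e) / antiderivative(e_init)


def inverse_cdf_bisection(u, e_init, theta, tol=1e-12, max_iter=200):
    """Solve toy_cdf(e) = u by bisection; the CDF is monotone, so [0, e_init] brackets the root."""
    lo, hi = 0.0, e_init
    for _ in range(max_iter):
        if hi - lo <= tol:
            break
        mid = 0.5 * (lo + hi)
        if toy_cdf(mid, e_init, theta) < u:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def exit_probability(e_crit, e_init, theta):
    """Assumed probability that one step lands below a hypothetical continuum threshold e_crit."""
    return toy_cdf(e_crit, e_init, theta)


# Monte Carlo check: the fraction of samples below the threshold should match the exit probability.
rng = random.Random(0)
samples = [inverse_cdf_bisection(rng.random(), 12.0, 0.3) for _ in range(10000)]
frac_below = sum(e < 3.0 for e in samples) / len(samples)
print(f"MC fraction below 3.0: {frac_below:.4f}, exit probability: {exit_probability(3.0, 12.0, 0.3):.4f}")
```

Because the CDF is monotone, bisection always brackets the root. Its cost is roughly log2(E_i / tol) CDF evaluations per sample.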

Quick start

The simplest way to see things working is to open the notebook:

  • params_test.ipynb: sampling and gradient descent in action.

If you prefer a pure-Python example, this snippet samples continuum paths, builds the discrete branching trees, samples discrete paths, and stitches full cascades:

```python
from jax import random
from pathgradient import META_PARAMS, NOMINAL_PARAMS, jax_cdf_minimum
from sampling import (
    sample_continuum_path,
    get_discrete_tree_body,
    get_full_discrete_tree_vmap,
    sample_discrete_path,
    stitch_paths,
)

# RNG key and initial (excitation) energy; split the key so the continuum and
# discrete stages draw independent random numbers instead of reusing one key
key = random.PRNGKey(0)
continuum_key, discrete_key = random.split(key)
initial_energy = 12.0

# 1) Sample continuum paths (+ gradients)
energies, last_energies, last_idx, continuum_cuts, \
    energy_theta_grads, energy_total_theta_grads, energy_Ei_grads, continuum_cut_grads = \
    sample_continuum_path(
        initial_energy=initial_energy,
        meta_params=META_PARAMS,
        params=NOMINAL_PARAMS,
        key=continuum_key,
        cdf_root_fun=jax_cdf_minimum,   # root function for the inverse CDF
        sample_num=256,
        passes=1,
        max_continuum_steps=5,
        enforce_decay_to_discrete=False,
        use_continuum_cut=True,
    )

# 2) Build the discrete-region tree body once (shared by all endpoints)
tree_body = get_discrete_tree_body(META_PARAMS, NOMINAL_PARAMS)

# 3) Assemble energy-dependent full discrete trees for each endpoint energy
discrete_probs_per_energy, discrete_paths = get_full_discrete_tree_vmap(
    last_energies, tree_body, META_PARAMS, NOMINAL_PARAMS
)

# 4) Sample one discrete path per cascade endpoint, then stitch complete energy paths
path_indices, chosen_discrete_paths, chosen_discrete_energies = sample_discrete_path(
    last_energies,
    (discrete_probs_per_energy, discrete_paths),
    META_PARAMS,
    NOMINAL_PARAMS,
    discrete_key,
)

full_energy_paths = stitch_paths(energies, chosen_discrete_energies)
```

Installation

Python 3.10+ is recommended.

```shell
python -m venv .venv
source .venv/bin/activate  # on Windows: .venv\Scripts\activate

pip install --upgrade pip
pip install numpy scipy matplotlib jax jaxopt

# Optional (Apple Silicon GPU via Metal):
pip install jax-metal
```

Notes:

  • If you need a different JAX build (e.g., CUDA), consult the official JAX installation instructions.
  • CPU-only installs work fine for these examples.

What’s in the box?

  • pathgradient.py

    • Level density and transition strength models (continuum side).
    • CDF and inverse-CDF ingredients for the continuum sampler.
    • Continuum cut probability and its gradient.
    • JAX-based gradients (via jacfwd) of the final energy with respect to the parameters and the initial energy.
    • Root-finding helpers for the inverse CDF built on jaxopt (e.g., bisection).
  • sampling.py

    • sample_continuum_path(...): sample piecewise continuum segments (with gradients) until the path exits to the discrete region.
    • Discrete branching utilities: get_discrete_tree_head, get_discrete_tree_body, get_full_discrete_tree, and get_full_discrete_tree_vmap.
    • Discrete sampling: sample_discrete_path, plus stitch_paths to join the continuum and discrete segments into a full cascade.
  • main.py

    • Placeholder for a script entry point (examples live primarily in notebooks).
  • Notebooks

    • params_test.ipynb: End-to-end sampling and gradient-based parameter tuning.
    • gibbs_samples.ipynb: Explorations of sampling strategies.
    • deexcitation.ipynb, jax_deex.ipynb, test_pathgradient.ipynb, jaxing.ipynb, bbb.ipynb: Various experiments and derivations.
  • Figures

    • Generated plots are saved under saved_images/ (e.g., gradient_paths.svg, discrete_tree_prob.svg, loss_vs_alpha.svg).
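To make the discrete branching process concrete, here is a minimal plain-Python sketch. It assumes levels are indexed upward from the ground state (0) and that `widths[i][j]` is the decay width from level i to a lower level j. Under those assumptions, branching ratios are the row-normalized widths, and a path's probability is the product of its branching ratios. The function names are hypothetical: they only mirror the roles of get_full_discrete_tree and sample_discrete_path and are not their actual signatures.

```python
import random

def branching_ratios(widths):
    """Row-normalize pairwise decay widths widths[i][j] (level i -> lower level j)."""
    ratios = []
    for i, row in enumerate(widths):
        total = sum(row[:i])  # only decays to lower levels are allowed
        ratios.append([w / total if total > 0 else 0.0 for w in row[:i]])
    return ratios


def enumerate_paths(start, ratios):
    """All decay paths from `start` down to the ground state (0), with their probabilities."""
    if start == 0:
        return [([0], 1.0)]
    paths = []
    for j, b in enumerate(ratios[start]):
        if b == 0.0:
            continue  # forbidden transition: prune this branch
        for tail, p in enumerate_paths(j, ratios):
            paths.append(([start] + tail, b * p))
    return paths


def sample_path(start, ratios, rng):
    """Draw one path by walking down the ladder, choosing each step by branching ratio."""
    path = [start]
    level = start
    while level != 0:
        level = rng.choices(range(level), weights=ratios[level])[0]
        path.append(level)
    return path


widths = [[], [1.0], [0.5, 1.5], [0.2, 0.3, 0.5]]
ratios = branching_ratios(widths)
for path, prob in enumerate_paths(3, ratios):
    print(path, prob)
print(sample_path(3, ratios, random.Random(1)))
```

With these example widths, `enumerate_paths(3, ratios)` returns four paths. Their probabilities are 0.2, 0.3, 0.125, and 0.375, which sum to 1.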

Model ingredients (toy)

The current toy model includes:

  • Continuum level density: a simple combination of a backshifted Fermi-gas term and a dispersion-like baseline controlled by NOMINAL_PARAMS['disp_parameter'].
  • Continuum transition strength: a smooth function of the gamma energy, parameterized by alpha and beta.
  • Discrete sector: user-provided energies and pairwise decay widths among the discrete levels.

All of these live in pathgradient.py and are easy to change for different physics assumptions.
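One way these ingredients can combine in a Hauser-Feshbach-style step is sketched below. The weight for decaying from E_i to a continuum energy E_f is taken as ρ(E_f) · T(E_i − E_f), then normalized on a grid into a CDF. The level density uses the textbook back-shifted Fermi-gas form ρ(E) = exp(2√(aU)) / (12√2 σ a^{1/4} U^{5/4}), with U = E − Δ. Several parts are assumptions and do not reflect the forms in pathgradient.py:

- The dispersion-like baseline is omitted.
- The transition strength α E_γ³ exp(−β E_γ) is hypothetical.
- All parameter values are hypothetical.

```python
import math

def bsfg_density(e, a=1.5, delta=0.5, sigma=3.0):
    """Back-shifted Fermi-gas level density (illustrative parameters, not NOMINAL_PARAMS)."""
    u = e - delta
    if u <= 0.0:
        return 0.0  # below the back-shift the density is set to zero
    return math.exp(2.0 * math.sqrt(a * u)) / (12.0 * math.sqrt(2.0) * sigma * a**0.25 * u**1.25)


def toy_strength(e_gamma, alpha=1.0, beta=0.5):
    """Hypothetical smooth transition strength alpha * E^3 * exp(-beta * E)."""
    return alpha * e_gamma**3 * math.exp(-beta * e_gamma)


def emission_cdf(e_init, n_grid=200):
    """Grid CDF over final energy E_f in (0, e_init), weighting each point by rho(E_f) * T(e_init - E_f)."""
    grid = [e_init * (k + 0.5) / n_grid for k in range(n_grid)]  # midpoints avoid E_gamma = 0
    weights = [bsfg_density(ef) * toy_strength(e_init - ef) for ef in grid]
    total = sum(weights)
    cdf, acc = [], 0.0
    for w in weights:
        acc += w / total
        cdf.append(acc)
    return grid, cdf
```

Feeding a uniform u through this grid CDF (for example with `bisect.bisect_left(cdf, u)`) gives a discretized version of the continuum sampler. The actual code instead inverts a continuous CDF by root finding.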

Reproducing results and figures

  • Start with params_test.ipynb to reproduce sampling and optimization traces. Many example figures are produced by the notebooks and written to saved_images/.
  • On Apple Silicon, jax-metal may speed up JAX operations in some environments; otherwise the CPU build works.

Development tips

  • JAX + control flow: the code mixes NumPy and JAX where convenient. Some helpers convert arrays back and forth between the two to work around device constraints; this is intentional, for simplicity.
  • Root finding for the inverse-CDF uses jaxopt.Bisection by default; other solvers are easy to swap in.
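  • Gradients of the implicitly defined inverse CDF come from jacfwd on the CDF combined with standard implicit differentiation. If F(x; θ) = u defines the sample x, then dx/dθ = −(∂F/∂θ) / (∂F/∂x), where ∂F/∂x is the density at x.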
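The implicit-differentiation formula can be checked numerically. This minimal sketch assumes a hypothetical toy CDF (density ∝ x + θx² on [0, 12]). Hand-written partial derivatives stand in for jacfwd. A central finite difference of the bisection root serves only as a cross-check.

```python
def toy_cdf(x, theta, e_init=12.0):
    """Hypothetical CDF on [0, e_init] for a density proportional to x + theta * x**2."""
    def antiderivative(y):
        return 0.5 * y**2 + theta * y**3 / 3.0
    return antiderivative(x) / antiderivative(e_init)


def invert(u, theta, e_init=12.0, tol=1e-13, max_iter=200):
    """Bisection root of toy_cdf(x) - u = 0 on [0, e_init]."""
    lo, hi = 0.0, e_init
    for _ in range(max_iter):
        if hi - lo <= tol:
            break
        mid = 0.5 * (lo + hi)
        if toy_cdf(mid, theta, e_init) < u:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def implicit_grad(u, theta, e_init=12.0):
    """dx/dtheta at fixed u via the implicit function theorem: -(dF/dtheta) / (dF/dx)."""
    x = invert(u, theta, e_init)
    a_x = 0.5 * x**2 + theta * x**3 / 3.0
    a_e = 0.5 * e_init**2 + theta * e_init**3 / 3.0
    dF_dx = (x + theta * x**2) / a_e  # the normalized density at x
    # Quotient rule on F = A(x) / A(e_init), with dA(y)/dtheta = y**3 / 3
    dF_dtheta = (x**3 / 3.0 * a_e - a_x * e_init**3 / 3.0) / a_e**2
    return -dF_dtheta / dF_dx


u, theta, h = 0.4, 0.3, 1e-5
finite_diff = (invert(u, theta + h) - invert(u, theta - h)) / (2 * h)
print(implicit_grad(u, theta), finite_diff)
```

The implicit gradient needs only the root x and the CDF's partial derivatives at x, not the solver's iterations. That is what lets the approach work with any root finder.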
