
Conversation

@eric-czech (Member) commented Sep 30, 2025

refs: https://github.com/Open-Athena/oa-cornell-dna/issues/65

This PR includes changes or examples necessary to train and evaluate a Levanter model on DNA as discussed more in marin-community#1729.

These changes are all much more experimental than I'd expect of a typical PR to the main Marin repo, i.e. there are hacks/patches deep in the levanter and marin code that I'm not yet sure are worth preserving.

Contributions

These changes mostly include groundwork like:

Results

Here is what conservation ROC looks like as a function of compute for the experiment I ran:

plantcad_scaling

source: plantcad_scaling.py.zip | plantcad_scaling.txt

For context, here are scores on this task from previous models:

  • PlantCAD1 zero-shot: ~69 ROC on this task [1]
  • PlantCAD2 zero-shot: ~72 ROC [2 / Figure 2B]
  • PlantCAD1 XGBoost FT: ~89 ROC [3 / Figure 3B]

And some rough estimates for compute used in pretraining based on an assumption of 50% MFU (see Notes / Scratch below for details):

  • PlantCAD1: ~1.3x10^21 FLOPS
  • PlantCAD2: ~1.5x10^23 FLOPS

This extrapolation suggests that we're not actually too far off from the PlantCAD1 zero-shot performance, i.e. we might even get there with as little as ~4.5x10^20 FLOPS under the more optimistic sigmoid fit. That would be less compute than PlantCAD1 used, which seems unlikely. The more pessimistic linear fit suggests crossing 70 ROC around ~1.4x10^21 FLOPS, which would be remarkably close to the PlantCAD1 pretraining compute (and seems more likely).
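For concreteness, the extrapolation amounts to something like the sketch below. This is illustrative only, not the code in plantcad_scaling.py: the (FLOPs, ROC) points are hand-copied from the eval table further down, and the fit settings (initial guesses and bounds) are chosen by eye:

# Illustrative sketch of the ROC-vs-compute extrapolation (not the plantcad_scaling.py code).
import numpy as np
from scipy.optimize import curve_fit

# A few (FLOPs, ROC) points hand-copied from the Haliax eval table below.
flops = np.array([3.40e18, 1.02e19, 1.70e19, 2.38e19, 3.40e19, 4.42e19])
roc = np.array([0.535, 0.550, 0.560, 0.570, 0.586, 0.593])
x = np.log10(flops)

def sigmoid(x, lo, hi, mid, scale):
    # ROC bounded between lo and hi, centered at mid on the log10(FLOPs) axis.
    return lo + (hi - lo) / (1.0 + np.exp(-(x - mid) / scale))

# Optimistic (sigmoid) and pessimistic (linear) fits; bounds/p0 are eyeballed guesses.
p_sig, _ = curve_fit(sigmoid, x, roc, p0=[0.5, 0.9, 21.0, 1.0],
                     bounds=([0.4, 0.6, 18.0, 0.1], [0.55, 1.0, 30.0, 10.0]))
p_lin = np.polyfit(x, roc, 1)

def flops_to_reach(target):
    # Solve each fit for the compute at which it crosses the target ROC.
    grid = np.linspace(x.max(), x.max() + 5, 10_000)  # extrapolate a few decades
    vals = sigmoid(grid, *p_sig)
    sig = grid[np.argmax(vals >= target)] if vals.max() >= target else float("nan")
    lin = (target - p_lin[1]) / p_lin[0]
    return 10 ** sig, 10 ** lin

print(flops_to_reach(0.69))  # rough compute needed to match PlantCAD1 zero-shot (~69 ROC)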

Here are some significant caveats with this estimate:

  1. This run covers 10 epochs (~23B tokens), and I make no attempt to account for diminishing returns across epochs
  2. I used the default cosine LR schedule; WSD or constant LR with warmup might have given more reliable estimates
  3. I did no significant hyperparameter tuning first; I'm blocked ATM on model sizing given how low MFU is for these tiny 512bp sequences and small models (order ~100M params)
  4. I have no idea what the final functional form should look like, so a sigmoid is just a guess since it fits this data reasonably well so far

Training and eval loss (log/log) were stable and followed predictable trends after a couple thousand steps (wandb run):

[screenshot: training and eval loss curves from the W&B run]

MFU for this 300M model is still ugly (~20%):

[screenshot: MFU from the W&B run]
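For reference, the back-of-the-envelope version of that number looks something like this (the throughput below is a made-up placeholder, not measured from this run; the only facts used are the ~6 FLOPs/param/token rule of thumb and the ~989 TFLOPS dense bf16 peak per H100):

# Rough MFU estimate: achieved model FLOPs per second over peak hardware FLOPs per second.
def mfu(n_params, tokens_per_sec, peak_flops):
    # ~6 FLOPs per parameter per token for forward+backward; ignores the attention
    # term, which is small for 512bp sequences.
    return 6 * n_params * tokens_per_sec / peak_flops

# e.g. a 300M-param model on 8xH100 (dense bf16 peak ~989 TFLOPS per GPU);
# tokens_per_sec is a placeholder, not this run's measured throughput.
print(mfu(n_params=300e6, tokens_per_sec=1.0e6, peak_flops=8 * 989e12))  # ~0.23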

Misc

Torch vs Haliax Evals

Here is the commit containing the original Haliax port of the torch causal LM eval: 1fa3a22. For my own sake, here is the original evaluation.py implementation using biofoundation: https://gist.github.com/eric-czech/63d6f5079bf91895cf93cf248ea988cb.

I ran both versions on the same checkpoints and the ROC scores all agree to within +/- 0.01. They're included below along with the corresponding checkpoint step, LR, and FLOP counts from W&B:

Eval result comparison

Haliax/jax eval results:

step,lr,roc_auc,flops
1673,9.93770809145e-05,0.535217,3.4043190387702497e+18
3346,9.7036921943e-05,0.546725,6.804570790755828e+18
5019,9.30542882997e-05,0.549917,1.0208889829526075e+19
6692,8.75831901794e-05,0.558042,1.3613208868296327e+19
8365,8.08332551969e-05,0.56029,1.701142697688957e+19
10038,7.30716710677e-05,0.565785,2.041574601565982e+19
11711,6.46127955405e-05,0.570048,2.3818031411037733e+19
13384,5.57357234356e-05,0.576358,2.722438409320032e+19
15057,4.68768230348e-05,0.583593,3.0624635845185896e+19
16730,3.83349033654e-05,0.585834,3.402488759717147e+19
18403,3.04164732369e-05,0.589215,3.742717299254939e+19
20076,2.34794097195e-05,0.588738,4.08294583879273e+19
21749,1.77572965185e-05,0.593178,4.423174378330522e+19

Biofoundation/torch eval results:

step,lr,roc_auc,flops
1673,9.93779613054e-05,0.53528062,3404319038770249728
5019,9.30571259232e-05,0.5486141,10208889829526075392
6692,8.75795158208e-05,0.55787022,13613208868296327168
11711,6.45971667836e-05,0.5689541,23818031411037732864
16730,3.83249971491e-05,0.58587594,34024887597171470336
21749,1.77692745637e-05,0.5934248799999999,44231743783305216000
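And a quick sanity check on the agreement between the two tables above (the file names are placeholders for the two CSV blocks):

# Compare ROC scores between the Haliax/jax and biofoundation/torch evals.
# File names are placeholders for the two tables above.
import pandas as pd

jax_df = pd.read_csv("haliax_eval.csv").set_index("step")
torch_df = pd.read_csv("torch_eval.csv").set_index("step")

# Inner join on checkpoint step, then look at the largest ROC disagreement.
merged = jax_df.join(torch_df, how="inner", lsuffix="_jax", rsuffix="_torch")
print((merged["roc_auc_jax"] - merged["roc_auc_torch"]).abs().max())  # ~0.001 here, well under +/- 0.01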

Notes

Scratch

PlantCAD1 FLOPS estimate:

- 312 TFLOPS bf16 (https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf)
- 3.12e+14 FLOPS * 32 A100s * 3 days * 86400 seconds/day = 2.59x10^21 FLOPS
- 50% MFU ==> ~1.3x10^21 FLOPS

PlantCAD2 FLOPS estimate:

- 1,979 TFLOPS bf16 (https://www.nvidia.com/en-us/data-center/h100/)
- 19.79e+14 FLOPS * 256 H100s * 7 days * 86400 seconds/day = 3.06×10^23
- 50% MFU ==> ~1.5x10^23 FLOPS
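Same arithmetic as a throwaway snippet:

# Pretraining compute: peak FLOPS x GPU count x wall-clock seconds x assumed MFU.
def pretraining_flops(peak_flops_per_gpu, n_gpus, days, mfu=0.5):
    return peak_flops_per_gpu * n_gpus * days * 86_400 * mfu

print(f"PlantCAD1: {pretraining_flops(312e12, 32, 3):.2e}")    # ~1.3e+21
print(f"PlantCAD2: {pretraining_flops(1979e12, 256, 7):.2e}")  # ~1.5e+23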

Checkpoint upload code:

cd sky_workdir
hf upload plantcad/_dev_marin_plantcad1_v1_train . --include 'local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-*'

Selective checkpoint upload code:

set -ex
cd /home/sky/sky_workdir/local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints
# keep the oldest checkpoint plus every 10th one after it
checkpoints=$(ls -tr | awk 'NR == 1 || NR % 10 == 0')
# step-360
# step-3346
# ...
# step-20076
for checkpoint in $checkpoints; do
  hf upload --repo-type dataset plantcad/_dev_marin_plantcad1_v1_train $checkpoint \
  local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/$checkpoint
done

eric-czech force-pushed the plantcad1-repro branch 2 times, most recently from 80c49b4 to 77ff9ec on September 30, 2025 16:14
eric-czech marked this pull request as draft on September 30, 2025 16:28
@gonzalobenegas (Member) commented Sep 30, 2025

Exciting!

The original PlantCAD1 model scored ~89 ROC on this task

How different is this eval from the one with ~69 AUROC here?

Fig. from the PlantCAD2 paper is also relevant here; I believe it's the same eval:

[figure from the PlantCAD2 paper]

@eric-czech (Member Author):

How different is this eval from the one with ~69 AUROC https://github.com/Open-Athena/oa-cornell-dna/issues/67?

Very different! I was mistakenly comparing to the XGBoost fine-tuning result that scored ~89, not the zero-shot performance. I updated my notes on that and added more context on past performance for the task + compute estimates. This feels much more within reach at 69 ROC. The linear extrapolation actually suggests needing something very close to the same compute as PlantCAD1, which would be interesting if it worked out that way.

@gonzalobenegas (Member) commented Sep 30, 2025

Some potential modifications that I think could be worthwhile for DNA (potentially helping reach PCAD1 performance faster):

  • An architecture with alternating global and local/sliding window attention (I believe levanter already has some of these).
  • If downweighting loss on repeat tokens (marked as lowercase) is not yet implemented, IMO fully masking their loss would still give better performance than giving them no special consideration (see the sketch after this list).
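To make that suggestion concrete, a hypothetical sketch (assuming per-nucleotide tokens where lowercase marks repeat/soft-masked bases; nothing like this is implemented in the PR yet, and the function name is made up):

# Hypothetical: per-token loss weights that mask (or downweight) repeat bases,
# assuming per-nucleotide tokens where lowercase marks soft-masked/repeat regions.
import numpy as np

def repeat_loss_weights(sequence, repeat_weight=0.0):
    # repeat_weight=0.0 fully masks repeats; 0 < repeat_weight < 1 downweights them.
    return np.array([repeat_weight if base.islower() else 1.0 for base in sequence])

print(repeat_loss_weights("ACGTacgtACGT"))  # 1,1,1,1,0,0,0,0,1,1,1,1 -> multiply the per-token loss by these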

@eric-czech (Member Author) commented Oct 3, 2025

Here are some results from a second iteration on this with a few changes (thanks @dlwh and @gonzalobenegas for suggestions):

  • 2x model size to try to reach a higher MFU (300M -> 600M, MFU went from ~20% -> ~25%)
  • A cyclic LR schedule with cooldowns prior to each eval (aiming for a WSD-S-like [1] schedule; rough shape sketched after this list)
  • A higher base learning rate (1e-4 -> 3e-4)
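For the LR schedule, the intended shape is roughly the following (illustrative sketch only; the step counts are placeholders and the real schedule is whatever the Levanter config in this PR specifies):

# Rough shape of the WSD-S-like cycle: warm up once, hold at the base LR, then
# decay linearly into each eval and jump back up afterwards.
def cyclic_lr(step, base_lr=3e-4, warmup=1_000, cycle=2_000, cooldown=200, min_lr=0.0):
    if step < warmup:
        return base_lr * step / warmup            # linear warmup
    pos = (step - warmup) % cycle                 # position within the current cycle
    if pos < cycle - cooldown:
        return base_lr                            # stable phase
    frac = (pos - (cycle - cooldown)) / cooldown  # 0 -> 1 across the cooldown
    return base_lr + frac * (min_lr - base_lr)    # linear decay into the eval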

The global batch size (2,048 examples, ~1M tokens) stayed the same, as did the total token count (28B tokens over 10 epochs).

Overall, this performed a good bit better despite a gigantic training loss spike that I didn't catch until the run had mostly recovered. I just let it finish out rather than restarting it:

plantcad_reproduction

Despite that giant loss spike, this trend appears better and I have to imagine any amount of hyperparam optimization would put the PlantCAD1 performance within reach. That might even be possible within this same compute budget, which was around 6 hrs for a single 8xH100 instance on CoreWeave (~$300). The run crashed a couple times on its own and I also killed it twice intentionally, so my actual costs were maybe double that. Either way, this seems like a decent step forward.

@eric-czech (Member Author):

One thing I didn't mention in the previous iteration: none of the ROC numbers shown were computed during training; everything was generated from checkpoints in a separate process. I did that because there is some kind of bug in the in-training version of this eval that I can't figure out. While I still don't know what's going on there, I added a test at 4e428d6 in the meantime to ensure parity (to 1e-3 precision) between ROC scores calculated with the jax/haliax implementation in this PR and the torch evaluation logic in Open-Athena/biofoundation@23f6745.

For posterity, here is the script that I used to generate data for this unit test (now at plantcad/ci):
create_reference_scores.py.zip.

eric-czech marked this pull request as ready for review on October 8, 2025 15:37
@eric-czech (Member Author) commented Oct 9, 2025

Here are results from a third iteration with:

  • Continued training from the final checkpoint in iteration 2 for another 10 epochs
  • Actual eval scores for the conservation data subset, computed with the existing PlantCAD 1 and 2 models, rather than published values
  • A direct comparison of eval loss and eval ROC scores

Otherwise, no significant aspects of the training process have changed (same base LR + schedule, same batch size, etc.).


First, ROC scores look like they might be saturating now, though it's still hard to say for sure. They are, however, leveling off at scores competitive with both of the largest PlantCAD1/2 models:

plantcad_scaling

source: plantcad_scaling.pdf

The most likely reason the PlantCAD2 scores aren't higher here is that PlantCAD2 was trained on much longer sequences than those used in this PlantCAD1 evaluation. That is, the ROC = 0.72 number I used in the previous iteration was based on a version of this same eval built with 8192bp sequences (cf. plantcad/PlantCAD2_zero_shot_tasks).

PlantCAD reference model scores
  • PlantCAD1-L (kuleshov-group/PlantCaduceus_l32): 0.6898068
  • PlantCAD1-S (kuleshov-group/PlantCaduceus_l20): 0.60242958
  • PlantCAD2-L (kuleshov-group/PlantCAD2-Large-l48-d1536): 0.67332978
  • PlantCAD2-S (kuleshov-group/PlantCAD2-Small-l24-d0768): 0.62516604

Second, the correlation between eval loss/perplexity and ROC is high, but the relationship between them is almost certainly not linear. Log-log plots of step vs eval loss appear to show some saturation now (after 12 epochs or so), and the improvements in ROC scores are slowing even faster:

plantcad_loss_vs_roc

source: plantcad_loss_vs_roc.pdf

Log loss plot

Same as the above with loss (x-axis) on log2 scale:

plantcad_loss_vs_roc_log2

source: plantcad_loss_vs_roc_log2.pdf


Third, the hot/cold model gap here is pretty striking on eval loss:

plantcad_loss_vs_lr

source: plantcad_loss_vs_lr.pdf

Those downward spikes correspond to evals run at the end of cooldowns, whereas the other data points are evals on hot models. The cooldowns basically "fast-forward" the whole training process by something like 1-5 full epochs (or more in some cases), which is a much bigger effect than I would have guessed. That gap also seems to widen near the end of training.
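To put a rough number on that fast-forward effect, the idea is just to find the step at which the hot loss curve would reach each cooled-down eval loss and convert the step gap into epochs. A hypothetical sketch of that calculation (not the code behind the plot):

# Hypothetical: how many epochs does a cooldown "fast-forward" training?
# Find the step at which the hot loss curve reaches the cooled-down eval loss.
import numpy as np

def fast_forward_epochs(hot_steps, hot_loss, cold_step, cold_loss, steps_per_epoch):
    # Hot loss decreases with step; np.interp needs increasing x, so reverse both arrays.
    # (Clips if the hot curve never gets down to cold_loss within the run.)
    equiv_step = np.interp(cold_loss, hot_loss[::-1], hot_steps[::-1])
    return (equiv_step - cold_step) / steps_per_epoch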


commit: 219cb84


@gonzalobenegas (Member):

Fascinating, I don't think I've seen this kind of analysis before! I'd love to see how they perform on the maize allele frequency eval that should soon be ready in subsampled form.

The plot says PlantCAD2 (8129bp) but maybe you meant 8192?

@eric-czech (Member Author) commented Oct 9, 2025

The plot says PlantCAD2 (8129bp) but maybe you meant 8192?

Whoops, thanks! Fixed that.

eric-czech pushed a commit that referenced this pull request Nov 20, 2025

Bump Levanter from c30de5b to 6cd783c (marin-community#1288), Dolma from fd431d0 to
79ce49d (#2), removing `enable_logprobs` parameter from both
`InferenceEngineConfig` and `Request` classes. The logprobs
functionality was removed in Levanter's inference engine refactor
(marin-community#1277).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>