Reproduce PlantCAD1 using Marin #2
base: main
Conversation
Exciting!
How different is this eval from the one with ~69 AUROC here? The figure from the PlantCAD2 paper is also relevant here; I believe it's the same eval:
Very different! I was originally comparing to the XGBoost fine-tuning score of ~89 by mistake, not the zero-shot performance. I updated my notes on that and added more context on past performance for the task + compute estimates. This feels much more within reach at 69 ROC. The linear extrapolation actually suggests needing something very close to the same compute as PlantCAD1, which would be interesting if it worked out that way.
Some potential modifications that I think could be worthwhile for DNA (potentially helping reach PCAD1 performance faster):
Here are some results from a second iteration on this with a few changes (thanks @dlwh and @gonzalobenegas for suggestions):
The global batch size (2,048 examples, ~1M tokens) stayed the same, as did the total token count (28B tokens, 10 epochs). Overall, this performed a good bit better despite a gigantic training loss spike that I didn't catch until the run had mostly recovered. I just let it finish out rather than restarting it:
Despite that giant loss spike, this trend appears better and I have to imagine any amount of hyperparam optimization would put the PlantCAD1 performance within reach. That might even be possible within this same compute budget, which was around 6 hrs for a single 8xH100 instance on CoreWeave (~$300). The run crashed a couple times on its own and I also killed it twice intentionally, so my actual costs were maybe double that. Either way, this seems like a decent step forward.
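As a quick back-of-the-envelope on the batch/token numbers above (the ~512 tokens per example is my assumption, inferred from 2,048 examples being ~1M tokens; it isn't stated explicitly):

```python
# Rough step/epoch arithmetic for the run described above.
examples_per_batch = 2_048
tokens_per_example = 512                                      # assumed, not stated above
tokens_per_batch = examples_per_batch * tokens_per_example    # ~1.05M tokens per batch

total_tokens = 28e9                                           # 28B tokens over 10 epochs
total_steps = total_tokens / tokens_per_batch                 # ~26.7k optimizer steps
tokens_per_epoch = total_tokens / 10                          # ~2.8B tokens seen per epoch

print(f"{tokens_per_batch:,} tokens/batch, ~{total_steps:,.0f} steps, {tokens_per_epoch:.1e} tokens/epoch")
```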
One thing I didn't mention in the previous iteration: none of the ROC numbers shown there were computed during training. Everything was generated from checkpoints in a separate process, because there is some kind of bug with the in-training version of this eval that I can't figure out. While I still don't know what's going on there, I added a test at 4e428d6 in the meantime to ensure exact parity between the two. For posterity, here is the script that I used to generate data for this unit test (now at plantcad/ci):
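Separately from that data-generation script, here is a minimal sketch of the shape such a parity test can take (the fixture path and field names are hypothetical, not the actual test added at 4e428d6):

```python
# Hypothetical parity test: the in-training eval hook and the offline
# checkpoint-based eval should agree on per-site scores (and therefore ROC)
# for a small, fixed fixture of labeled sites.
import json

import numpy as np
import pytest
from sklearn.metrics import roc_auc_score


def test_in_training_eval_matches_offline_eval():
    # Hypothetical fixture produced by a data-generation script like the one above.
    with open("tests/data/conservation_eval_fixture.json") as f:
        record = json.load(f)
    labels = np.asarray(record["labels"])
    offline_scores = np.asarray(record["offline_scores"])
    in_training_scores = np.asarray(record["in_training_scores"])

    # Per-site scores should match up to float tolerance, not just the aggregate metric.
    np.testing.assert_allclose(in_training_scores, offline_scores, atol=1e-6)
    assert roc_auc_score(labels, in_training_scores) == pytest.approx(
        roc_auc_score(labels, offline_scores), abs=1e-6
    )
```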
Here are results from a third iteration with:
Otherwise, no significant aspects of the training process have changed (same base LR + schedule, same batch size, etc.). First, ROC scores look like they might be saturating now, though it still seems hard to say for sure. They are, however, leveling off at scores competitive with both of the largest PlantCAD1/2 models:
source: plantcad_scaling.pdf

The most likely reason the PlantCAD2 scores aren't higher in this case is that PlantCAD2 was trained on much longer sequences than those used in this PlantCAD1 evaluation, i.e. the ROC = .72 number I used in the previous iteration was based on a version of this same eval built with 8192bp sequences (cf. plantcad/PlantCAD2_zero_shot_tasks).

PlantCAD reference model scores
Second, the correlation between eval loss/perplexity and ROC is high, but the relationship between them is almost certainly not linear. Log-log plots of step vs eval loss appear to be showing some saturation now (after 12 epochs or so), and the improvements in ROC scores are slowing even faster:
source: plantcad_loss_vs_roc.pdf

Log loss plot: same as the above with loss (x-axis) on a log2 scale:
source: plantcad_loss_vs_roc_log2.pdf

Third, the hot/cold model gap here is pretty striking on eval loss:
source: plantcad_loss_vs_lr.pdf

Those downward spikes correspond to evals run at the end of cooldowns, whereas the other data points were evals on hot models. The cooldowns basically "fast-forward" the whole training process by something like 1-5 full epochs (or more in some cases), which is a much bigger effect than I would have guessed. That gap seems to widen near the end of training too.

commit: 219cb84
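For context on the hot/cold distinction, here is a sketch of the general schedule shape that produces those two kinds of eval points (placeholder constants, not the actual schedule config used in the run):

```python
# Generic warmup-stable-decay shape: the "hot" model keeps a constant LR after
# warmup, and each cooldown branch decays the LR toward zero before the
# end-of-cooldown ("cold") eval. All constants here are placeholders.
def hot_lr(step: int, base_lr: float = 3e-4, warmup_steps: int = 1_000) -> float:
    """LR on the main (hot) trajectory: linear warmup, then constant."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr


def cooldown_lr(step: int, branch_step: int, cooldown_steps: int = 2_000,
                base_lr: float = 3e-4) -> float:
    """LR on a cooldown branch started from the hot model at `branch_step`,
    decaying linearly to zero over `cooldown_steps`."""
    progress = min(max(step - branch_step, 0) / cooldown_steps, 1.0)
    return hot_lr(branch_step, base_lr) * (1.0 - progress)
```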
Fascinating, I don't think I've seen this kind of analysis before! I'd love to see how they perform on the maize allele frequency eval that should soon be ready in subsampled form. The plot says PlantCAD2 (8129bp) but maybe you meant 8192?
Whoops, thanks! Fixed that.
Bump Levanter from c30de5b to 6cd783c (marin-community#1288) and Dolma from fd431d0 to 79ce49d (#2), removing the `enable_logprobs` parameter from both the `InferenceEngineConfig` and `Request` classes. The logprobs functionality was removed in Levanter's inference engine refactor (marin-community#1277). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude <noreply@anthropic.com>






refs: https://github.com/Open-Athena/oa-cornell-dna/issues/65
This PR includes changes and examples necessary to train and evaluate a Levanter model on DNA, as discussed further in marin-community#1729.
These changes are all much more experimental than I imagine any PR to the main Marin repo being; i.e. there are hacks/patches deeper in the Levanter and Marin code that I'm not yet sure are worth preserving.
Contributions
These changes mostly include groundwork like:
Results
Here is what conservation ROC looks like as a function of compute for the experiment I ran:
source: plantcad_scaling.py.zip | plantcad_scaling.txt
For context, here are scores on this task from previous models:
And here are some rough estimates for compute used in pretraining, based on an assumption of 50% MFU (see Notes / Scratch below for details):

This extrapolation suggests that we're not actually too far off from the PlantCAD1 zero-shot performance, i.e. we might get there as soon as 4.5x10^20 FLOPS when assuming the more optimistic sigmoid curve. That would be faster than PlantCAD1, which seems unlikely. The more pessimistic linear estimate suggests crossing 70 ROC around ~1.4x10^21 FLOPS, which would be remarkably close to the PlantCAD1 pretraining compute (and seems more likely).
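For reference, here is a sketch of how an extrapolation like this can be set up, fitting both a sigmoid and a linear trend of ROC against log10(FLOPS) and solving each for a target score (the data points and fit constants below are placeholders, not the actual measurements behind the plot):

```python
# Fit an optimistic (sigmoid) and pessimistic (linear) trend of ROC vs
# log10(FLOPS) and solve each for the compute needed to hit a target ROC.
import numpy as np
from scipy.optimize import brentq, curve_fit

# Placeholder observations (log10 FLOPS, ROC) -- not the real run data.
log_flops = np.array([17.0, 18.0, 19.0, 20.0, 21.0])
roc = np.array([0.52, 0.56, 0.63, 0.68, 0.70])


def sigmoid(x, mid, scale, lo=0.50, hi=0.75):
    # Asymptotes are fixed assumptions to keep this toy fit stable.
    return lo + (hi - lo) / (1.0 + np.exp(-(x - mid) / scale))


(mid, scale), _ = curve_fit(sigmoid, log_flops, roc, p0=[19.0, 1.0])
slope, intercept = np.polyfit(log_flops, roc, 1)

target = 0.70
sigmoid_cross = brentq(lambda x: sigmoid(x, mid, scale) - target, 15.0, 30.0)
linear_cross = (target - intercept) / slope
print(f"sigmoid: ~10^{sigmoid_cross:.1f} FLOPS, linear: ~10^{linear_cross:.1f} FLOPS to reach ROC {target}")
```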
Here are some significant caveats with the extrapolation above:
Training and eval loss (log/log) were stable and follow predictable trends after a couple thousand steps (wandb run):
MFU for this 300M model is still ugly (~20%):
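For reference, here is a sketch of how an MFU figure like that is typically derived (the throughput below is a placeholder chosen to land near ~20%, not the measured value; 989 TFLOPS is the H100 dense bf16 peak):

```python
# MFU = achieved model FLOPs/sec divided by aggregate peak hardware FLOPs/sec,
# with model FLOPs/token approximated by the usual 6 * n_params.
def mfu(n_params, tokens_per_second, n_gpus, peak_flops_per_gpu=989e12):
    return (6 * n_params * tokens_per_second) / (n_gpus * peak_flops_per_gpu)


# Placeholder throughput for a ~300M-param model on 8 GPUs.
print(f"MFU ~= {mfu(n_params=300e6, tokens_per_second=0.9e6, n_gpus=8):.1%}")
```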
Misc
Torch vs Haliax Evals
Here is the commit containing the original Haliax port of the torch causal LM eval: 1fa3a22. For my own sake, here is the original `evaluation.py` implementation using `biofoundation`: https://gist.github.com/eric-czech/63d6f5079bf91895cf93cf248ea988cb.

I ran both versions on the same checkpoints and the ROC scores are all the same +/- .01. They're included below along with the corresponding checkpoint step, LR, and FLOP counts from W&B:
Eval result comparison
Haliax/jax eval results:
Biofoundation/torch eval results:
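For reference, the zero-shot scoring that both implementations boil down to is roughly: take per-site causal-LM log-probabilities, score each site by the log-likelihood of the reference base, and compute AUROC against the conservation labels. A minimal sketch (the exact score definition is an assumption, not necessarily the precise one used in either implementation):

```python
# Minimal sketch: given per-site log-probabilities from a causal LM, score each
# labeled site by the log-likelihood of the reference base and compute AUROC.
import numpy as np
from sklearn.metrics import roc_auc_score


def conservation_auroc(log_probs: np.ndarray, ref_base_ids: np.ndarray,
                       labels: np.ndarray) -> float:
    """
    log_probs:    [n_sites, vocab] log P(base | left context) at each labeled site
    ref_base_ids: [n_sites] integer id of the reference base at each site
    labels:       [n_sites] 1 = conserved, 0 = not conserved
    """
    scores = log_probs[np.arange(len(ref_base_ids)), ref_base_ids]
    return roc_auc_score(labels, scores)
```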
Notes
Scratch
PlantCAD1 FLOPS estimate:
PlantCAD2 FLOPS estimate:
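One common way to make estimates like these is the 6 * N * D approximation plus an assumed MFU; here is a generic sketch with placeholder numbers (not PlantCAD1/2's actual parameter or token counts):

```python
# Pretraining compute ~= 6 * n_params * n_tokens; GPU-hours then follow from an
# assumed MFU and a per-GPU peak throughput.
def pretrain_flops(n_params: float, n_tokens: float) -> float:
    return 6 * n_params * n_tokens


def gpu_hours(total_flops: float, peak_flops_per_gpu: float, mfu: float = 0.5) -> float:
    return total_flops / (peak_flops_per_gpu * mfu) / 3600


# Placeholder example; 312e12 is the A100 dense bf16 peak.
flops = pretrain_flops(n_params=200e6, n_tokens=100e9)
print(f"{flops:.2e} FLOPS, ~{gpu_hours(flops, peak_flops_per_gpu=312e12):.0f} A100-hours at 50% MFU")
```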
Checkpoint upload code:
Selective checkpoint download code: