Hello,
TL;DR
I have a couple of recommendations I'd like to discuss:
- Disable caching by default, to avoid wrong summaries after module changes and to avoid leaking heavy references into memory.
- Use the models themselves as keys for the cache dictionary (see the sketch after this list).
- Possibly, consider using only weak references in `ModelStatistics`, though this seems like too much work at this point.
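As a rough sketch of what the second and third points could look like (hypothetical code, not the current torchinfo API):

```python
import weakref

from torchinfo.layer_info import LayerInfo

# Hypothetical replacement for the module-level cache: keyed weakly by
# the model object itself, so the entry vanishes as soon as the model
# is garbage collected; no leaked references, no __name__ collisions.
_cached_forward_pass: "weakref.WeakKeyDictionary[object, list[LayerInfo]]" = (
    weakref.WeakKeyDictionary()
)
```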
Explanations
I recently spent a good 5 hours with `gc` trying to understand why my neural network was never garbage collected after I had used torchinfo and deleted everything.
After these three steps:
create net -> compute summary -> delete summary and net
I no longer had any handle on my net or my summary, and yet they were never garbage collected!
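A minimal version of what I did looks roughly like this:

```python
import gc
import weakref

import torch.nn as nn
from torchinfo import summary

net = nn.Sequential(nn.Linear(10, 10))
probe = weakref.ref(net)  # a probe that does not keep net alive

stats = summary(net, input_size=(1, 10))
del net, stats
gc.collect()

print(probe())  # expected None, but the module-level cache still holds
                # LayerInfo objects, and through them the whole model
```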
I finally figured out that the culprit was `_cached_forward_pass`:
`torchinfo/torchinfo/torchinfo.py`, line 51 in e67e748:

```python
_cached_forward_pass: dict[str, list[LayerInfo]] = {}
```
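Clearing this dict by hand (fragile, since it is a private attribute) does let the model be collected, which confirmed the diagnosis:

```python
import gc

import torchinfo.torchinfo

# clearing the private dict drops the cached LayerInfo lists...
torchinfo.torchinfo._cached_forward_pass.clear()
gc.collect()  # ...and the model finally gets collected
```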
Debugging was hard even with `gc.get_referrers`, since `LayerInfo` has a very uninformative `repr` that obfuscates the nature of the object to anyone not familiar with torchinfo's internals.
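For anyone digging into a similar leak (with `net` still in scope), printing the type of each referrer rather than the object itself makes the lists of `LayerInfo` much easier to spot:

```python
import gc

# type(...) makes torchinfo's internals recognisable where repr() does not
for referrer in gc.get_referrers(net):
    print(type(referrer))
```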
Furthermore, I noticed that the keys of this dict are simply the class `__name__`:
`torchinfo/torchinfo/torchinfo.py`, line 271 in e67e748:

```python
model_name = model.__class__.__name__
```
I think this is more confusing than using a plain reference to the model, and it can also create collisions between classes sharing the same `__name__` (think of throwaway names for dynamically created classes, or simply nested classes), as the snippet below shows.
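For example (hypothetical toy models):

```python
import torch.nn as nn

def make_net(width: int) -> nn.Module:
    class Net(nn.Module):  # nested class: __name__ is always "Net"
        def __init__(self) -> None:
            super().__init__()
            self.fc = nn.Linear(width, width)

        def forward(self, x):
            return self.fc(x)

    return Net()

a, b = make_net(10), make_net(20)
print(type(a).__name__ == type(b).__name__)  # True: two different
                                             # architectures, one cache key
```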
Finally, the cache may lead to incorrect output if the network changes before being "re-summarized": in effect, the second call would reuse the same cached forward pass, and so produce the same, now stale, summary.
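Concretely, a scenario like this (assuming caching is active, as discussed above):

```python
import torch.nn as nn
from torchinfo import summary

net = nn.Sequential(nn.Linear(10, 10))
summary(net, input_size=(1, 10))  # forward pass cached under "Sequential"

net.add_module("act", nn.ReLU())  # change the network in place
summary(net, input_size=(1, 10))  # with caching on, this can replay the
                                  # cached pass and never see the new ReLU
```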
All the best!
Élie