
Disable forward pass caching by default #369

@ego-thales

Description


Hello,

TL;DR

I have a couple of recommendations I'd like to discuss:

  1. Disable caching by default, both to avoid returning a stale summary after the module changes and to avoid leaking strong references to heavy objects into memory.
  2. Use the model itself as the key for the cache dictionary.
  3. Eventually consider holding only weak references in ModelStatistics, though this seems like too much work at this point.

Explanations

I recently spent a good five hours with gc trying to understand why my neural network was never garbage collected after I had used torchinfo and deleted everything.

After these three steps:
create net -> compute summary -> delete summary and net
I no longer held any handle on my net or summary, and yet they were never garbage collected!

I finally figured out that the culprit was _cached_forward_pass:

_cached_forward_pass: dict[str, list[LayerInfo]] = {}

Debugging was hard even with gc.get_referrers, since LayerInfo has a very uninformative repr that obfuscates the nature of the object to anyone not familiar with the internals.
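To illustrate the debugging difficulty: gc.get_referrers only hands back the immediate containers, which for a cached entry is an anonymous list, with nothing pointing at the module-level dict it lives in. A small sketch (with a stand-in `LayerInfo`, not torchinfo's):

```python
import gc

class LayerInfo:  # stand-in; like the real one, its repr reveals nothing useful
    pass

cache = {"Net": [LayerInfo()]}
obj = cache["Net"][0]

# The referrer of obj is just a bare list; you still have to walk
# gc.get_referrers on that list to discover the cache dict above it.
referrers = [r for r in gc.get_referrers(obj) if isinstance(r, list)]
print(referrers)
```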

Furthermore, I noticed that the keys of this dict are plain class __name__s:

model_name = model.__class__.__name__

I think this is more confusing than using a plain reference to the model, and it can also create conflicts between classes sharing the same __name__ (think of unimportant names for dynamically created classes, for example, or simply nested classes).
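The collision is easy to trigger with a factory function (a hypothetical example, not torchinfo code): every class it creates shares the same __name__, so a cache keyed on __name__ silently serves one model's entry for the other.

```python
def make_model(num_features):
    # dynamically created class: every call yields __name__ == "Net"
    class Net:
        def __init__(self):
            self.num_features = num_features
    return Net()

a = make_model(4)
b = make_model(8)
assert a.__class__.__name__ == b.__class__.__name__ == "Net"

# a __name__-keyed cache cannot tell the two models apart
cache = {}
cache[a.__class__.__name__] = "summary of a"
assert cache.get(b.__class__.__name__) == "summary of a"
```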

Finally, the cache can produce incorrect output if the network changes before being "re-summarized": in effect, the cached forward pass is reused, so the same, now stale, summary is returned.
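A toy model of that staleness (the `fake_summary` function and `Net` class are hypothetical stand-ins for the real summary machinery): once a result is cached under the class name, later structural changes to the model are invisible.

```python
cache = {}

def fake_summary(model):
    # keyed by class name, like torchinfo's _cached_forward_pass
    key = model.__class__.__name__
    if key not in cache:
        cache[key] = len(model.layers)  # pretend this is the computed summary
    return cache[key]

class Net:
    def __init__(self):
        self.layers = ["linear1", "relu"]

net = Net()
assert fake_summary(net) == 2

net.layers.append("linear2")  # the model changes...
assert fake_summary(net) == 2  # ...but the cached, stale summary is returned
```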

All the best!
Élie
