[MX] Rowwise W cached for backwards #2546

@jeromeku


Describe the bug

In transformer_engine.pytorch.Linear, both the rowwise and the columnwise quantized copies of the weight W are saved for the backward pass.

With MX mixed precision, however, the backward pass needs only the columnwise copy of W, so the saved rowwise copy holds memory without being used.

Notably, the newer transformer_engine.pytorch.ops.BasicLinear already discards the rowwise copy and saves only the columnwise one.
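
To make the memory cost concrete, here is a minimal toy sketch of the behaviour described above. It does not use Transformer Engine: `ToyMXWeight`, `keep_for_backward`, and `saved_bytes` are hypothetical names made up for illustration, the buffers are plain `bytes`, and scale factors are ignored. It also assumes the rowwise copy is consumed only by the forward GEMM.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class ToyMXWeight:
    """Hypothetical stand-in for a quantized weight; not a Transformer Engine class."""
    rowwise: Optional[bytes]     # assumed: copy consumed by the forward GEMM
    columnwise: Optional[bytes]  # copy needed by the backward pass


def keep_for_backward(w: ToyMXWeight) -> ToyMXWeight:
    """Return what should be saved for backward: the columnwise copy only."""
    if w.columnwise is None:
        raise ValueError("backward needs the columnwise copy")
    return replace(w, rowwise=None)


def saved_bytes(w: ToyMXWeight) -> int:
    """Total size of the buffers still held by the object."""
    return sum(len(buf) for buf in (w.rowwise, w.columnwise) if buf is not None)


weight = ToyMXWeight(rowwise=bytes(1024), columnwise=bytes(1024))
saved = keep_for_backward(weight)
print(saved_bytes(weight), saved_bytes(saved))  # 2048 1024
```

In this toy setup, dropping the rowwise copy before saving halves the weight memory held between forward and backward, which is the saving this issue is asking for in Linear.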

Labels: bug
