
[JAX] Support vmapped cudnn attention #2545

@qGentry

Description


Is your feature request related to a problem? Please describe.

cuDNN attention doesn't work when vmapped because of this assertion:
https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/cpp_extensions/attention.py#L563
This limits its usability when implementing pipeline parallelism, which requires vmapping the staged layers over the stage buffer.
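For context, a minimal sketch of the pattern that fails, assuming the usual (batch, seqlen, heads, head_dim) layout. `attention` below is a hypothetical pure-JAX stand-in for a layer that internally calls TE's cuDNN fused attention; with the real primitive, the `jax.vmap` over the stage axis is what trips the assertion.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a layer that calls TE's cuDNN fused attention.
def attention(q, k, v):
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / (q.shape[-1] ** 0.5)
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)

stages, batch, seq, heads, dim = 4, 2, 128, 8, 64
# Per-stage activations stacked along a leading "stage" axis (the stage buffer
# used for pipeline parallelism).
x = jnp.zeros((stages, batch, seq, heads, dim))
# Mapping over the stage axis; with TE's cuDNN attention this hits the assertion.
out = jax.vmap(attention)(x, x, x)
```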

Describe the solution you'd like

Properly implemented vmap rules for cuDNN attention.
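One way to express the kind of rule being asked for, sketched with `jax.custom_batching.custom_vmap` and a pure-JAX reference in place of the real cuDNN kernel (the actual fix would presumably register a batching rule on TE's fused-attention primitives): fold the mapped axis into the existing batch dimension, run the unbatched kernel once, and split the result back out. All function names below are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import custom_batching

# Pure-JAX reference standing in for the unbatched cuDNN fused-attention call;
# layout assumed to be (batch, seqlen, heads, head_dim).
def fused_attention(q, k, v):
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / (q.shape[-1] ** 0.5)
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)

@custom_batching.custom_vmap
def attention(q, k, v):
    return fused_attention(q, k, v)

@attention.def_vmap
def attention_vmap_rule(axis_size, in_batched, q, k, v):
    # Fold the mapped axis (e.g. the pipeline-stage axis) into the existing
    # batch dimension, call the unbatched kernel once, then split it back out.
    def fold(x, batched):
        if not batched:
            x = jnp.broadcast_to(x, (axis_size,) + x.shape)
        return x.reshape((axis_size * x.shape[1],) + x.shape[2:])

    q, k, v = (fold(t, b) for t, b in zip((q, k, v), in_batched))
    out = fused_attention(q, k, v)
    out = out.reshape((axis_size, -1) + out.shape[1:])
    return out, True

# With such a rule in place, the pipeline-style vmap over the stage axis works:
stages, batch, seq, heads, dim = 4, 2, 128, 8, 64
x = jnp.zeros((stages, batch, seq, heads, dim))
out = jax.vmap(attention)(x, x, x)   # (stages, batch, seq, heads, dim)
```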
