**Is your feature request related to a problem? Please describe.**

The cudnn attention primitive does not work under `jax.vmap` because of this assertion: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/cpp_extensions/attention.py#L563. This limits its usability for pipeline parallelism, which requires vmapping the staged layers over the stage buffer.

**Describe the solution you'd like**

Properly implemented vmap (batching) rules for the cudnn attention primitive.
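For reference, a minimal sketch of the desired behavior using a pure-JAX scaled dot-product attention as a stand-in for the cudnn fused kernel (the function and shapes below are illustrative, not TransformerEngine's API). The stacked `(q, k, v)` arrays play the role of a pipeline stage buffer, and `jax.vmap` maps attention over the leading stage axis; the cudnn path currently trips the linked assertion on this pattern:

```python
import jax
import jax.numpy as jnp

def attention(q, k, v):
    # Pure-JAX reference attention; stand-in for the cudnn fused kernel.
    scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,khd->qhd", weights, v)

# A "stage buffer": one (q, k, v) per pipeline stage, stacked on a
# leading axis. vmap maps the per-stage attention over that axis.
stages, seq, heads, dim = 4, 8, 2, 16
key = jax.random.PRNGKey(0)
qkv = jax.random.normal(key, (3, stages, seq, heads, dim))
out = jax.vmap(attention)(qkv[0], qkv[1], qkv[2])
print(out.shape)  # (4, 8, 2, 16)
```

The same call with the TransformerEngine cudnn attention substituted for `attention` is what this request asks to support, e.g. by registering proper batching rules for the underlying primitive.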