
Conversation

@sudhakarsingh27 (Collaborator)

Description

cuDNN recently made it possible to return any subset of {Stats, SumExp, Max}. This PR adapts TE to always get Stats from cuDNN, plus the Max tensor when return_max_logit=True. (Note that Stats = log(SumExp) + Max.)
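The identity in parentheses can be sanity-checked numerically. A minimal NumPy sketch (variable names are illustrative, not TE's or cuDNN's actual API):

```python
import numpy as np

# A row of attention logits (pre-softmax scores).
logits = np.array([1.5, -0.3, 2.2, 0.7])

# The quantities cuDNN can return per row:
max_ = logits.max()                    # Max
sum_exp = np.exp(logits - max_).sum()  # SumExp (computed stably, shifted by Max)
stats = np.log(sum_exp) + max_         # Stats = log(SumExp) + Max

# Stats is exactly the log-sum-exp of the logits, which is what the
# backward pass needs to recompute the softmax row without storing it:
softmax = np.exp(logits - stats)

assert np.isclose(stats, np.log(np.exp(logits).sum()))
assert np.isclose(softmax.sum(), 1.0)
```

Because Stats alone suffices to rebuild the softmax, SumExp never needs to leave cuDNN.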

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • fused_attn_f16_arbitrary_seqlen.cu
    • Removed references to the SumExp tensor, since cuDNN now returns Stats by default.
    • Set generate_stats=true, which forces cuDNN to always return the Stats tensor (needed in the backward pass).
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py
    • Removed the code that manually computed Stats = log(SumExp) + Max, since cuDNN returns Stats directly and TE no longer needs SumExp from cuDNN.
  • Updated the corresponding documentation.
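In effect, the Python-side change swaps a manual reconstruction for a direct read. A hedged NumPy sketch of the before/after (names like sum_exp and max_logits are illustrative, not the actual TE variables):

```python
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 4))  # stand-in for per-row attention logits

# Before: cuDNN handed back Max and SumExp, and the wrapper reconstructed
# Stats = log(SumExp) + Max by hand.
max_logits = logits.max(axis=-1, keepdims=True)
sum_exp = np.exp(logits - max_logits).sum(axis=-1, keepdims=True)
stats_manual = np.log(sum_exp) + max_logits

# After: cuDNN returns Stats directly (generate_stats=true), so the wrapper
# just reads it from the output tensors instead of recomputing it.
stats_from_cudnn = np.log(np.exp(logits).sum(axis=-1, keepdims=True))

assert np.allclose(stats_manual, stats_from_cudnn)
```

Both paths yield the same tensor; removing the manual step simply deletes redundant work and the SumExp plumbing around it.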

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

sudhakarsingh27 and others added 5 commits February 12, 2026 13:12
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

greptile-apps bot commented Feb 12, 2026

Greptile Overview

Greptile Summary

This PR adapts TransformerEngine to leverage a recent cuDNN enhancement that allows returning Stats (log(SumExp)+Max) directly, eliminating the need to compute it manually from separate SumExp and Max tensors.

Key Changes:

  • C++/CUDA: Set generate_stats=true unconditionally to always retrieve Stats from cuDNN
  • Removed SumExp tensor creation and management throughout the codebase
  • Updated tensor ordering: now returns (Stats, Max) instead of (Max, SumExp) when return_max_logit=True
  • Python wrapper: Removed manual computation of Stats = log(SumExp) + Max, now uses cuDNN-provided Stats directly
  • Documentation: Updated comments to reflect new tensor order and removal of SumExp

The changes are well-coordinated across the C++/CUDA backend, Python wrapper, and documentation. The backward pass correctly consumes the Stats tensor from aux_ctx_tensors[0], maintaining compatibility with existing code.

Confidence Score: 5/5

  • This PR is safe to merge - it's a clean refactoring that leverages cuDNN functionality with no logical issues found
  • The changes are well-structured and consistent across all layers. The refactoring removes manual computation in favor of cuDNN-provided values, tensor ordering is correctly updated throughout, and the backward pass properly consumes the Stats tensor. No breaking changes or bugs detected.
  • No files require special attention

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | Removes SumExp tensor handling and always generates Stats from cuDNN, with a conditional Max tensor when return_max_logit=True. Tensor ordering changed from (Max, SumExp) to (Stats, Max). |
| transformer_engine/pytorch/cpp_extensions/fused_attn.py | Updates the Python wrapper to use Stats directly from cuDNN instead of computing Stats = log(SumExp) + Max. Correctly accesses output_tensors[1] for Stats and output_tensors[2] for Max. |
| transformer_engine/pytorch/csrc/extensions/attention.cpp | Updates documentation comments to reflect the new tensor order (Stats first, then Max when return_max_logit=True), removing references to SumExp. |

Sequence Diagram

sequenceDiagram
    participant Py as Python Wrapper
    participant CPP as C++ Extension
    participant CUDA as CUDA Kernel
    participant cuDNN as cuDNN API

    Note over Py,cuDNN: Forward Pass with return_max_logit=True

    Py->>CPP: fused_attn_fwd(return_max_logit=True)
    CPP->>CUDA: fused_attn_arbitrary_seqlen_fwd()
    Note over CUDA: Set generate_stats=true (always)
    CUDA->>cuDNN: sdpa_options.set_logit_max(Max)
    CUDA->>cuDNN: mha_graph->sdpa() with generate_stats
    cuDNN-->>CUDA: Returns (O, Stats, Max)
    Note over CUDA: Stats_tuple = (Stats, Max)
    CUDA-->>CPP: output_tensors[0]=O, [1]=Stats, [2]=Max
    CPP-->>Py: output_tensors
    Note over Py: aux_ctx_tensors=[Stats]
    Note over Py: max_logit=amax(Max, dims)
    Py-->>Py: Return (O, aux_ctx_tensors, max_logit)

    Note over Py,cuDNN: Backward Pass

    Py->>CPP: fused_attn_bwd(aux_ctx_tensors)
    Note over CPP: aux_ctx_tensors[0] = Stats
    CPP->>CUDA: Pass Stats as devPtrSoftmaxStats
    CUDA->>cuDNN: dsdpa with Stats tensor
    cuDNN-->>CUDA: Returns (dQ, dK, dV)
    CUDA-->>CPP: Gradients
    CPP-->>Py: (dQ, dK, dV)

Last reviewed commit: 2d7b51b

@greptile-apps bot left a comment

3 files reviewed, 1 comment

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…27/TransformerEngine into fix_return_stats_max_cudnn
@greptile-apps bot left a comment

3 files reviewed, no comments
