[PyTorch][Fused Attn] Add support for cuDNN to return Softmax Stats always and Max when return_max_logit=True
#2677
base: main
Conversation
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Greptile Overview

Greptile Summary

This PR adapts TransformerEngine to leverage a recent cuDNN enhancement that allows returning any subset of {Stats, SumExp, Max}.

Key Changes:
The changes are well-coordinated across the C++/CUDA backend, Python wrapper, and documentation. The backward pass correctly consumes the Stats tensor.

Confidence Score: 5/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant Py as Python Wrapper
participant CPP as C++ Extension
participant CUDA as CUDA Kernel
participant cuDNN as cuDNN API
Note over Py,cuDNN: Forward Pass with return_max_logit=True
Py->>CPP: fused_attn_fwd(return_max_logit=True)
CPP->>CUDA: fused_attn_arbitrary_seqlen_fwd()
Note over CUDA: Set generate_stats=true (always)
CUDA->>cuDNN: sdpa_options.set_logit_max(Max)
CUDA->>cuDNN: mha_graph->sdpa() with generate_stats
cuDNN-->>CUDA: Returns (O, Stats, Max)
Note over CUDA: Stats_tuple = (Stats, Max)
CUDA-->>CPP: output_tensors[0]=O, [1]=Stats, [2]=Max
CPP-->>Py: output_tensors
Note over Py: aux_ctx_tensors=[Stats]
Note over Py: max_logit=amax(Max, dims)
Py-->>Py: Return (O, aux_ctx_tensors, max_logit)
Note over Py,cuDNN: Backward Pass
Py->>CPP: fused_attn_bwd(aux_ctx_tensors)
Note over CPP: aux_ctx_tensors[0] = Stats
CPP->>CUDA: Pass Stats as devPtrSoftmaxStats
CUDA->>cuDNN: dsdpa with Stats tensor
cuDNN-->>CUDA: Returns (dQ, dK, dV)
CUDA-->>CPP: Gradients
CPP-->>Py: (dQ, dK, dV)
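The forward-side tensors in the diagram above can be sketched numerically. A minimal pure-Python illustration (tensor names follow the diagram; shapes and the tiny logit values are illustrative, not the actual TE/cuDNN interface):

```python
import math

# Toy pre-softmax attention logits S for one head: seq_q=2 rows,
# seq_kv=4 columns. Values are made up for illustration.
S = [[0.5, -1.0, 2.0, 0.0],
     [1.5, 0.25, -0.75, 3.0]]

# Max: per-row maximum logit (what cuDNN can return when
# return_max_logit=True).
Max = [max(row) for row in S]

# Stats: the softmax statistics cuDNN always returns for the backward
# pass, Stats = log(SumExp) + Max, i.e. a numerically stable log-sum-exp.
SumExp = [sum(math.exp(s - m) for s in row) for row, m in zip(S, Max)]
Stats = [math.log(se) + m for se, m in zip(SumExp, Max)]

# The Python wrapper then reduces Max down to a single max_logit
# (max_logit = amax(Max, dims) in the diagram).
max_logit = max(Max)

# Sanity check: Stats equals a direct log-sum-exp of each row.
for row, st in zip(S, Stats):
    assert math.isclose(st, math.log(sum(math.exp(s) for s in row)))
```

Subtracting `Max` before exponentiating is what makes the log-sum-exp numerically stable for large logits; the identity itself is exact.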
Last reviewed commit: 2d7b51b
3 files reviewed, 1 comment
3 files reviewed, no comments
Description
cuDNN recently made it possible to return any subset of {Stats, SumExp, Max}. This PR adapts TE to always get Stats from cuDNN, and the Max tensor when return_max_logit=True. (Note that Stats = log(SumExp) + Max.)

Type of change
Changes
Please list the changes introduced in this PR:
fused_attn_f16_arbitrary_seqlen.cu
- Removed the SumExp tensor, as it's not needed since cuDNN returns Stats by default.
- Set generate_stats=True, which forces cuDNN to always return the Stats tensor (needed in the backward pass).

transformer_engine/pytorch/cpp_extensions/fused_attn.py
- Removed the computation Stats = log(SumExp) + Max, since cuDNN returns Stats directly and TE doesn't need SumExp from cuDNN.

Checklist:
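A small sketch of why the Stats tensor alone suffices for the backward pass, so SumExp need not be stored: with Stats = log(SumExp) + Max, the softmax probabilities are recovered as P = exp(S - Stats) without any extra reduction over keys. Pure-Python illustration; the logit values and variable names are made up, not the cuDNN API.

```python
import math

# Toy logits for one query row attending over four keys.
S = [0.5, -1.0, 2.0, 0.0]

# Forward-side statistic saved for backward: Stats = log(SumExp) + Max.
Max = max(S)
Stats = math.log(sum(math.exp(s - Max) for s in S)) + Max

# Backward-style recomputation: exp(S - Stats) is exactly softmax(S),
# so neither SumExp nor Max has to be kept around separately.
P = [math.exp(s - Stats) for s in S]

# Check against a direct softmax.
ref = [math.exp(s) / sum(math.exp(t) for t in S) for s in S]
assert all(math.isclose(p, r) for p, r in zip(P, ref))
assert math.isclose(sum(P), 1.0)
```

This is the standard log-sum-exp trick used by fused-attention backward kernels: one saved scalar per row replaces both auxiliary tensors.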