[AMD] Support fp16 upcast in scaled dot #5543

Open
antiagainst wants to merge 4 commits into main from amd-mxfp-fp16

Conversation

antiagainst (Collaborator) commented:

AMD gfx9 architectures do not have native bf16 VALU instructions, so bf16 scaling can be expensive.

This commit prototypes upcasting to fp16 for the computation instead. That means relaxing the dot_scaled frontend and the upcast_mxfp op definition to also accept fp16.

Right now, for prototyping, the fp16 path is turned on whenever one input is fp16. A more proper approach might be to introduce a math_dtype attribute for explicit control.
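For illustration, here is a minimal sketch of what the relaxed frontend could look like from the kernel side. It assumes tl.dot_scaled's existing signature with string operand formats and that "fp16" becomes an accepted format for the non-scaled operand; the pointer layout and block sizes are illustrative, not taken from this PR:

import triton
import triton.language as tl

@triton.jit
def mxfp4_fp16_dot(a_ptr, a_scale_ptr, b_ptr, c_ptr,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                   BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # mxfp4 packs two e2m1 values per byte, so a row holds BLOCK_K // 2 bytes.
    offs_kp = tl.arange(0, BLOCK_K // 2)
    a = tl.load(a_ptr + offs_m[:, None] * (BLOCK_K // 2) + offs_kp[None, :])
    # One e8m0 scale per 32 elements along K.
    offs_s = tl.arange(0, BLOCK_K // 32)
    a_scale = tl.load(a_scale_ptr + offs_m[:, None] * (BLOCK_K // 32) + offs_s[None, :])
    # The RHS is an ordinary fp16 tensor and carries no scale.
    offs_k = tl.arange(0, BLOCK_K)
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    # With this change, the mxfp4 side can be upcast to fp16 rather than
    # bf16 on gfx9, avoiding emulated bf16 VALU arithmetic.
    c = tl.dot_scaled(a, a_scale, "e2m1", b, None, "fp16")
    tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], c)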

@antiagainst force-pushed the amd-mxfp-fp16 branch 2 times, most recently from 34259a4 to d085268 on January 9, 2025 00:10
@antiagainst marked this pull request as ready for review on January 10, 2025 06:58
@lezcano (Contributor) left a comment:

Just one small question.

FWIW, once #5475 becomes a thing all this will be trivial to implement. That PR is still very much WIP tho :)

Comment on lines +386 to +389
RankedTensorType
UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
ScaleDotElemType inputElemType,
Type outputElemType) {

Why this change?

Software emulation enables targeting hardware architectures without native microscaling
operation support. Right now, in such cases, the microscaled lhs/rhs are upcast to
:code:`bf16` element type beforehand for the dot computation, with one exception:
for AMD CDNA3 specifically, if one of the inputs is of normal :code:`fp16` element type,
nit.

Suggested change:
- for AMD CDNA3 specifically, if one of the inputs is of normal :code:`fp16` element type,
+ for AMD CDNA3 specifically, if one of the inputs is of :code:`fp16` element type,
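For context on what the emulation in the excerpt above computes, here is a minimal standalone sketch of the mxfp4 (e2m1) decode and e8m0 scale application, following the OCP microscaling format definitions; the helper names are hypothetical and this is not the op's actual lowering:

def decode_e2m1(nibble: int) -> float:
    # e2m1: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit.
    sign = -1.0 if (nibble >> 3) & 1 else 1.0
    exp = (nibble >> 1) & 0b11
    man = nibble & 0b1
    if exp == 0:
        mag = 0.5 * man  # subnormal: 0.0 or 0.5
    else:
        mag = (1.0 + 0.5 * man) * 2.0 ** (exp - 1)
    return sign * mag  # representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6

def apply_e8m0_scale(value: float, scale_byte: int) -> float:
    # An e8m0 scale is a pure power of two with exponent bias 127.
    return value * 2.0 ** (scale_byte - 127)

# 0b0111 decodes to 6.0; a scale byte of 126 (i.e. 2**-1) halves it to 3.0.
assert apply_e8m0_scale(decode_e2m1(0b0111), 126) == 3.0

Note that fp16 has a much narrower exponent range than bf16, so large e8m0 scales that fit in bf16 can overflow fp16; that trade-off is part of what an explicit math_dtype control would let users opt into.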
