[AMD] Support fp16 upcast in scaled dot #5543
Conversation
AMD gfx9 architectures do not have native bf16 VALU instructions, so bf16 scaling can be expensive. This commit prototypes upcasting to fp16 for computation instead. That means relaxing the `dot_scaled` frontend and the `upcast_mxfp` op definition to also accept fp16. For prototyping, the fp16 path is currently enabled whenever one input is fp16; a more proper approach might be to introduce a `math_dtype` attribute to control this explicitly.
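For context, a minimal pure-Python sketch of what an `upcast_mxfp`-style op does before the dot: decode the packed mxfp4 (e2m1) values and apply the shared e8m0 block scale. The function names and block layout here are illustrative assumptions, not the actual Triton implementation; the point is that the decoded result can land in fp16 just as well as bf16, which is what this PR enables.

```python
# Hedged sketch (not the Triton code): decode MXFP4 values (e2m1: 1 sign
# bit, 2 exponent bits, 1 mantissa bit) and apply a shared e8m0 scale,
# i.e. the upcast an upcast_mxfp-style op performs before the dot.

E2M1_BIAS = 1    # exponent bias of the 4-bit e2m1 format
E8M0_BIAS = 127  # exponent bias of the 8-bit e8m0 scale format

def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit e2m1 value into a Python float."""
    sign = -1.0 if (nibble >> 3) & 1 else 1.0
    exp = (nibble >> 1) & 0b11
    man = nibble & 1
    if exp == 0:                      # subnormal encodings: 0.0 or 0.5
        return sign * man * 0.5
    return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - E2M1_BIAS)

def upcast_block(nibbles, scale_e8m0: int):
    """Upcast a block of e2m1 values sharing one e8m0 scale (2**(s-127))."""
    scale = 2.0 ** (scale_e8m0 - E8M0_BIAS)
    return [decode_e2m1(n) * scale for n in nibbles]

# e8m0 scale 128 means multiply by 2; 0b0011 encodes 1.5, 0b1011 is -1.5.
print(upcast_block([0b0011, 0b1011], 128))  # [3.0, -3.0]
```

Whether the resulting floats are then rounded to bf16 or fp16 is exactly the choice this PR exposes; on gfx9, fp16 is the cheaper target for the subsequent VALU scaling.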
Just one small question.
FWIW, once #5475 becomes a thing all this will be trivial to implement. That PR is still very much WIP tho :)
RankedTensorType
UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
                               ScaleDotElemType inputElemType,
                               Type outputElemType) {
Why this change?
Software emulation enables targeting hardware architectures without native microscaling
operation support. Right now for such case, microscaled lhs/rhs are upcasted to
:code:`bf16` element type beforehand for dot computation, with one exception:
for AMD CDNA3 specifically, if one of the inputs is of normal :code:`fp16` element type,
nit.
Suggested change:
- for AMD CDNA3 specifically, if one of the inputs is of normal :code:`fp16` element type,
+ for AMD CDNA3 specifically, if one of the inputs is of :code:`fp16` element type,