Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#15713: Use ZEROACC properly for FP32 case #59

Merged
merged 3 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions common/inc/cmath_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,4 +228,27 @@ inline constexpr int get_math_fidelity_increment(const int math_fidelity_desc)
return ((math_fidelity_desc >> 3) & 0x1) + 1;
}

// Returns DEST base in faces for 16-bit DEST mode. Half of the DEST can store 32 faces,
// so "base in faces" is whatever get_dest_buffer_base returns, divided by 16.
inline std::uint32_t get_dest_buffer_base_16b()
{
return (get_dest_buffer_base() >> 4);
}

// Returns DEST base in faces for 32-bit DEST mode. Half of the DEST can store 16 faces,
// so "base in faces" is whatever get_dest_buffer_base returns, divided by 32.
inline std::uint32_t get_dest_buffer_base_32b()
{
return (get_dest_buffer_base() >> 5);
}

// Returns the offset represented in DEST rows for a given face of a given tile.
inline std::uint32_t get_dest_index_in_faces(const std::uint32_t dst_index, const std::uint32_t face_index)
{
// dst_index << 2 gives a tile idex in faces, because there are 4 faces in a tile.
// face_index should normally take values from {0, 1, 2, 3}, although if it's greater
// than 3 faces from next tiles can be accessed.
return (dst_index << 2) + face_index;
}

} // namespace ckernel::math
55 changes: 27 additions & 28 deletions llk_lib/llk_math_eltwise_binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,21 @@ inline void _llk_math_eltwise_binary_(const std::uint32_t num_faces, uint dst_in
// Mop for col broadcast only does 2 outerloops. Needs to clear B manually and call twice
constexpr uint32_t outerloop = (binary_reuse_dest != EltwiseBinaryReuseDestType::NONE) ? 2 : 1;
#pragma GCC unroll 0
for (std::uint32_t n = 0; n < outerloop; n++) { // N-num faces
for (std::uint32_t face_num = 0; face_num< outerloop; face_num++) { // N-num faces
eltwise_binary_reuse_dest_as_src<binary_reuse_dest>();
ckernel_template::run(instrn_buffer);
}
TTI_SETRWC(p_setrwc::CLR_B, 0, 0, 0, 0, 0);
#pragma GCC unroll 0
for (std::uint32_t n = 0; n < outerloop; n++) { // N-num faces
for (std::uint32_t face_num = 0; face_num< outerloop; face_num++) { // N-num faces
eltwise_binary_reuse_dest_as_src<binary_reuse_dest>();
ckernel_template::run(instrn_buffer);
}
TTI_SETRWC(p_setrwc::CLR_B, 0, 0, 0, 0, 0);
} else {
constexpr uint32_t outerloop = (binary_reuse_dest != EltwiseBinaryReuseDestType::NONE) ? 4 : 1;
#pragma GCC unroll 0
for (std::uint32_t n = 0; n < outerloop; n++) { // N-num faces
for (std::uint32_t face_num = 0; face_num< outerloop; face_num++) { // N-num faces
ncvetkovicTT marked this conversation as resolved.
Show resolved Hide resolved
eltwise_binary_reuse_dest_as_src<binary_reuse_dest>();
ckernel_template::run(instrn_buffer);
}
Expand All @@ -110,29 +110,28 @@ inline void _llk_math_eltwise_binary_(const std::uint32_t num_faces, uint dst_in
constexpr uint32_t outerloop = (binary_reuse_dest != EltwiseBinaryReuseDestType::NONE) ? 2 : 1;
if constexpr (high_fidelity) {
#pragma GCC unroll 0
for (std::uint32_t n = 0; n < 2; n++) { // N-num faces
for (std::uint32_t face_num = 0; face_num < 2; face_num++) { // N-num faces
eltwise_binary_reuse_dest_as_src<binary_reuse_dest>();
if constexpr (binary_reuse_dest != EltwiseBinaryReuseDestType::NONE) {
// fp32 zeroacc can only clear 8x16 datums at a time, need to call twice per 16x16 face
// We clear the DEST face-by-face, given the DEST base, tile index and face index
if (is_fp32_dest_acc_en && clear_fp32_dst_acc) {
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 3)) + (0 + n*2)); // Clear lower half of faces 0 & 1 (offsets 0, 2)
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 3)) + (0 + ((n*2)+1))); // Clear upper half of faces 0 & 1 (offsets: 1, 3)
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, (get_dest_buffer_base_32b() + get_dest_index_in_faces(dst_index, (0 + face_num)))); // Clear faces 0 & 1
} else {
TT_ZEROACC(ZERO_ACC_MODE, 0/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 2)) + (0 + n)); // Clear faces 0 & 1
TT_ZEROACC(ZERO_ACC_MODE, 0/*clear fp32*/, 0, ADDR_MOD_1, (get_dest_buffer_base_16b() + get_dest_index_in_faces(dst_index, (0 + face_num)))); // Clear faces 0 & 1
}
}
ckernel_template::run(instrn_buffer);
}
} else {
#pragma GCC unroll 0
for (std::uint32_t n = 0; n < outerloop; n++) { // N-num faces
for (std::uint32_t face_num = 0; face_num< outerloop; face_num++) { // N-num faces
eltwise_binary_reuse_dest_as_src<binary_reuse_dest>();
if constexpr (binary_reuse_dest != EltwiseBinaryReuseDestType::NONE) {
// We clear the DEST face-by-face, given the DEST base, tile index and face index
if (is_fp32_dest_acc_en && clear_fp32_dst_acc) {
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 3)) + (0 + n*2)); // Clear lower half of faces 0 & 1
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 3)) + (0 + ((n*2)+1))); // Clear upper half of faces 0 & 1
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, (get_dest_buffer_base_32b() + get_dest_index_in_faces(dst_index, (0 + face_num)))); // Clear faces 0 & 1
} else {
TT_ZEROACC(ZERO_ACC_MODE, 0/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 2)) + (0 + n)); // Clear faces 0 & 1
TT_ZEROACC(ZERO_ACC_MODE, 0/*clear fp32*/, 0, ADDR_MOD_1, (get_dest_buffer_base_16b() + get_dest_index_in_faces(dst_index, (0 + face_num)))); // Clear faces 0 & 1
}
}
ckernel_template::run(instrn_buffer);
Expand All @@ -141,28 +140,28 @@ inline void _llk_math_eltwise_binary_(const std::uint32_t num_faces, uint dst_in
TTI_SETRWC(p_setrwc::CLR_B, 0, 0, 0, 0, 0);
if constexpr (high_fidelity) {
#pragma GCC unroll 0
for (std::uint32_t n = 0; n < 2; n++) { // N-num faces
for (std::uint32_t face_num = 0; face_num < 2; face_num++) { // N-num faces
eltwise_binary_reuse_dest_as_src<binary_reuse_dest>();
if constexpr (binary_reuse_dest != EltwiseBinaryReuseDestType::NONE) {
// We clear the DEST face-by-face, given the DEST base, tile index and face index
if (is_fp32_dest_acc_en && clear_fp32_dst_acc) {
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 3)) + (4 + n*2)); // Clear lower half of faces 2 & 3 (offsets: 4, 6)
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 3)) + (4 + ((n*2)+1))); // Clear upper half of faces 2 & 3 (offsets: 5, 7)
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, (get_dest_buffer_base_32b() + get_dest_index_in_faces(dst_index, (2 + face_num)))); // Clear faces 2 & 3
} else {
TT_ZEROACC(ZERO_ACC_MODE, 0/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 2)) + (2 + n)); // Clear faces 2 & 3
TT_ZEROACC(ZERO_ACC_MODE, 0/*clear fp32*/, 0, ADDR_MOD_1, (get_dest_buffer_base_16b() + get_dest_index_in_faces(dst_index, (2 + face_num)))); // Clear faces 2 & 3
}
}
ckernel_template::run(instrn_buffer);
}
} else {
#pragma GCC unroll 0
for (std::uint32_t n = 0; n < outerloop; n++) { // N-num faces
for (std::uint32_t face_num = 0; face_num< outerloop; face_num++) { // N-num faces
eltwise_binary_reuse_dest_as_src<binary_reuse_dest>();
if constexpr (binary_reuse_dest != EltwiseBinaryReuseDestType::NONE) {
// We clear the DEST face-by-face, given the DEST base, tile index and face index
if (is_fp32_dest_acc_en && clear_fp32_dst_acc) {
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 3)) + (4 + n*2)); // Clear lower half of faces 2 & 3 (offsets: 4, 6)
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 3)) + (4 + ((n*2)+1))); // Clear upper half of faces 2 & 3 (offsets: 5, 7)
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, (get_dest_buffer_base_32b() + get_dest_index_in_faces(dst_index, (2 + face_num)))); // Clear faces 2 & 3
} else {
TT_ZEROACC(ZERO_ACC_MODE, 0/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 2)) + (2 + n)); // Clear faces 2 & 3
TT_ZEROACC(ZERO_ACC_MODE, 0/*clear fp32*/, 0, ADDR_MOD_1, (get_dest_buffer_base_16b() + get_dest_index_in_faces(dst_index, (2 + face_num)))); // Clear faces 2 & 3
}
}
ckernel_template::run(instrn_buffer);
Expand All @@ -174,28 +173,28 @@ inline void _llk_math_eltwise_binary_(const std::uint32_t num_faces, uint dst_in
const uint32_t outerloop = (binary_reuse_dest != EltwiseBinaryReuseDestType::NONE) ? num_faces : 1;
if constexpr (high_fidelity) {
#pragma GCC unroll 0
for (std::uint32_t n = 0; n < num_faces; n++) { // N-num faces
for (std::uint32_t face_num = 0; face_num < num_faces; face_num++) { // N-num faces
eltwise_binary_reuse_dest_as_src<binary_reuse_dest>();
if constexpr (binary_reuse_dest != EltwiseBinaryReuseDestType::NONE) {
// We clear the DEST face-by-face, given the DEST base, tile index and face index
if (is_fp32_dest_acc_en && clear_fp32_dst_acc) {
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 3)) + n*2);
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 3)) + ((n*2)+1));
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, (get_dest_buffer_base_32b() + get_dest_index_in_faces(dst_index, face_num)));
} else {
TT_ZEROACC(ZERO_ACC_MODE, 0/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 2)) + n);
TT_ZEROACC(ZERO_ACC_MODE, 0/*clear fp32*/, 0, ADDR_MOD_1, (get_dest_buffer_base_16b() + get_dest_index_in_faces(dst_index, face_num)));
}
}
ckernel_template::run(instrn_buffer);
}
} else {
#pragma GCC unroll 0
for (std::uint32_t n = 0; n < outerloop; n++) { // N-num faces
for (std::uint32_t face_num = 0; face_num< outerloop; face_num++) { // N-num faces
eltwise_binary_reuse_dest_as_src<binary_reuse_dest>();
if constexpr (binary_reuse_dest != EltwiseBinaryReuseDestType::NONE) {
// We clear the DEST face-by-face, given the DEST base, tile index and face index
if (is_fp32_dest_acc_en && clear_fp32_dst_acc) {
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 3)) + n*2);
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 3)) + ((n*2)+1));
TT_ZEROACC(ZERO_ACC_MODE, 1/*clear fp32*/, 0, ADDR_MOD_1, (get_dest_buffer_base_32b() + get_dest_index_in_faces(dst_index, face_num)));
} else {
TT_ZEROACC(ZERO_ACC_MODE, 0/*clear fp32*/, 0, ADDR_MOD_1, ((get_dest_buffer_base() >> 4) + (dst_index << 2)) + n);
TT_ZEROACC(ZERO_ACC_MODE, 0/*clear fp32*/, 0, ADDR_MOD_1, (get_dest_buffer_base_16b() + get_dest_index_in_faces(dst_index, face_num)));
}
}
ckernel_template::run(instrn_buffer);
Expand Down
Loading