From 1ce5d1310e994adfa24da0b6d9470219ff7bf462 Mon Sep 17 00:00:00 2001 From: Radomir Djogo Date: Mon, 16 Sep 2024 18:18:55 +0000 Subject: [PATCH] Split hw (re)config between unpack/math threads --- common/inc/cunpack_common.h | 8 ++------ llk_lib/llk_math_common.h | 29 ++++++++++++++++++++++++++--- llk_lib/llk_unpack_common.h | 14 ++++---------- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/common/inc/cunpack_common.h b/common/inc/cunpack_common.h index 072eba3..41ac495 100644 --- a/common/inc/cunpack_common.h +++ b/common/inc/cunpack_common.h @@ -241,10 +241,7 @@ namespace ckernel::unpacker ((uint)unpA_dst_format_masked == (uint)DataFormat::Int32) || ((uint)unpB_dst_format_masked == (uint)DataFormat::Int32); - constexpr uint alu_format_mask = ALU_FORMAT_SPEC_REG0_SrcA_MASK | ALU_FORMAT_SPEC_REG1_SrcB_MASK | - ALU_FORMAT_SPEC_REG0_SrcAUnsigned_MASK | ALU_FORMAT_SPEC_REG0_SrcBUnsigned_MASK; - alu_payload.f.ALU_FORMAT_SPEC_REG0_SrcA = unpA_dst_format_masked; - alu_payload.f.ALU_FORMAT_SPEC_REG1_SrcB = row_pool ? ((uint) DataFormat::Float16 | (exp_width<<2)) : unpB_dst_format_masked; + constexpr uint alu_format_mask = ALU_FORMAT_SPEC_REG0_SrcAUnsigned_MASK | ALU_FORMAT_SPEC_REG0_SrcBUnsigned_MASK; if ((uint)unpA_src_format == (uint)DataFormat::UInt8) { alu_payload.f.ALU_FORMAT_SPEC_REG0_SrcAUnsigned = 1; @@ -257,10 +254,9 @@ namespace ckernel::unpacker // NOTE: This assumes these config fields are adjacent and in same register!! static_assert(ALU_ACC_CTRL_Fp32_enabled_ADDR32 == ALU_FORMAT_SPEC_REG0_SrcA_ADDR32); static_assert(ALU_ACC_CTRL_Fp32_enabled_ADDR32 == ALU_ACC_CTRL_SFPU_Fp32_enabled_ADDR32); - constexpr uint alu_dest_format_mask = ALU_ACC_CTRL_INT8_math_enabled_MASK | ALU_ACC_CTRL_SFPU_Fp32_enabled_MASK | ALU_ACC_CTRL_Fp32_enabled_MASK; + constexpr uint alu_dest_format_mask = ALU_ACC_CTRL_SFPU_Fp32_enabled_MASK | ALU_ACC_CTRL_Fp32_enabled_MASK; alu_payload.f.ALU_ACC_CTRL_Fp32_enabled = fp32_dest_acc_en; alu_payload.f.ALU_ACC_CTRL_SFPU_Fp32_enabled = fp32_dest_acc_en; - alu_payload.f.ALU_ACC_CTRL_INT8_math_enabled = int8_math_enabled; constexpr uint alu_stoch_rnd_mask = ALU_ROUNDING_MODE_Fpu_srnd_en_MASK | ALU_ROUNDING_MODE_Gasket_srnd_en_MASK | ALU_ROUNDING_MODE_Packer_srnd_en_MASK; alu_payload.f.ALU_ROUNDING_MODE_Fpu_srnd_en = fpu_srnd_en; alu_payload.f.ALU_ROUNDING_MODE_Gasket_srnd_en = pack_srnd_en; diff --git a/llk_lib/llk_math_common.h b/llk_lib/llk_math_common.h index 874906a..6d566fd 100644 --- a/llk_lib/llk_math_common.h +++ b/llk_lib/llk_math_common.h @@ -14,8 +14,8 @@ using namespace ckernel::math; -template -inline void _llk_math_hw_configure() { +template +inline void _llk_math_hw_configure_(const std::uint32_t srca_data_format, const std::uint32_t srcb_data_format) { //Untilize mode needs dest read access with a stride of 16 //Following bits are needed for enabling stride of 16 cfg_reg_rmw_tensix(untilize_en); @@ -24,6 +24,14 @@ inline void _llk_math_hw_configure() { // Legacy mode for ZEROACC cfg_reg_rmw_tensix(1); + if constexpr (skip_inputs == false){ + TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::MATH); + uint int8_math_enabled = ((uint)(srca_data_format & 0xF) == (uint)DataFormat::Int8) || + ((uint)(srcb_data_format & 0xF) == (uint)DataFormat::Int8) || + ((uint)srca_data_format == (uint)DataFormat::Int32) || + ((uint)srcb_data_format == (uint)DataFormat::Int32); + cfg_reg_rmw_tensix(int8_math_enabled); + } } template @@ -122,12 +130,27 @@ inline void _llk_math_debug_dump_seek_(std::uint8_t offset) { debug_dump_seek(offset); } -//Following functions not needed for blackhole since ALU format is inferred +// Following functions do not need to program ALU_FORMAT_SPEC_REG0_SrcA/ALU_FORMAT_SPEC_REG1_SrcB +// for blackhole since ALU format is inferred inline void _llk_math_reconfig_data_format_srca_(const std::uint32_t srca_data_format) { + TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::MATH); + uint int8_math_enabled = ((uint)(srca_data_format & 0xF) == (uint)DataFormat::Int8) || + ((uint)srca_data_format == (uint)DataFormat::Int32); + cfg_reg_rmw_tensix(int8_math_enabled); } inline void _llk_math_reconfig_data_format_srcb_(const std::uint32_t srcb_data_format) { + TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::MATH); + uint int8_math_enabled = ((uint)(srcb_data_format & 0xF) == (uint)DataFormat::Int8) || + ((uint)srcb_data_format == (uint)DataFormat::Int32); + cfg_reg_rmw_tensix(int8_math_enabled); } inline void _llk_math_reconfig_data_format_(const std::uint32_t srca_data_format, const std::uint32_t srcb_data_format) { + TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::MATH); + uint int8_math_enabled = ((uint)(srca_data_format & 0xF) == (uint)DataFormat::Int8) || + ((uint)(srcb_data_format & 0xF) == (uint)DataFormat::Int8) || + ((uint)srca_data_format == (uint)DataFormat::Int32) || + ((uint)srcb_data_format == (uint)DataFormat::Int32); + cfg_reg_rmw_tensix(int8_math_enabled); } diff --git a/llk_lib/llk_unpack_common.h b/llk_lib/llk_unpack_common.h index e320a1a..2d1df59 100644 --- a/llk_lib/llk_unpack_common.h +++ b/llk_lib/llk_unpack_common.h @@ -82,16 +82,8 @@ inline void _llk_unpack_config_tile_dim_srcb_impl_(const std::uint32_t face_r_di inline void _llk_unpack_reconfig_data_format_srca_impl_(const std::uint32_t unpack_src_format, const std::uint32_t unpack_dst_format, const std::uint32_t tile_size) { - alu_config_u alu_payload = {.val = 0}; - alu_payload.f.ALU_FORMAT_SPEC_REG0_SrcA = unpack_dst_format; - if ((uint)unpack_src_format == (uint)DataFormat::UInt8) { - alu_payload.f.ALU_FORMAT_SPEC_REG0_SrcAUnsigned = 1; - } - alu_payload.f.ALU_ACC_CTRL_INT8_math_enabled = ((uint)(unpack_dst_format & 0xF) == (uint)DataFormat::Int8) || - ((uint)unpack_dst_format == (uint)DataFormat::Int32); - constexpr uint alu_mask = ALU_FORMAT_SPEC_REG0_SrcA_MASK | ALU_FORMAT_SPEC_REG0_SrcAUnsigned_MASK | ALU_ACC_CTRL_INT8_math_enabled_MASK; - cfg_reg_rmw_tensix(alu_payload.val); - + TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::UNPACK0); + cfg_reg_rmw_tensix(((uint)unpack_src_format == (uint)DataFormat::UInt8) ? 1 : 0); cfg_reg_rmw_tensix(unpack_src_format); cfg_reg_rmw_tensix(unpack_dst_format); TT_SETDMAREG(0, LOWER_HALFWORD(tile_size), 0, LO_16(p_gpr_unpack::TILE_SIZE_A)); // update gpr which holds tile size A @@ -99,6 +91,8 @@ inline void _llk_unpack_reconfig_data_format_srca_impl_(const std::uint32_t unpa inline void _llk_unpack_reconfig_data_format_srcb_impl_(const std::uint32_t unpack_src_format, const std::uint32_t unpack_dst_format, const std::uint32_t tile_size) { + TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::UNPACK1); + cfg_reg_rmw_tensix(((uint)unpack_src_format == (uint)DataFormat::UInt8) ? 1 : 0); cfg_reg_rmw_tensix(unpack_src_format); cfg_reg_rmw_tensix(unpack_dst_format); TT_SETDMAREG(0, LOWER_HALFWORD(tile_size), 0, LO_16(p_gpr_unpack::TILE_SIZE_B)); // update gpr which holds tile size B