diff --git a/src/neural/cuda/fp16_kernels.cu b/src/neural/cuda/fp16_kernels.cu index 11a403b658..c0315ef542 100644 --- a/src/neural/cuda/fp16_kernels.cu +++ b/src/neural/cuda/fp16_kernels.cu @@ -134,7 +134,15 @@ void Se_Fp16_NHWC(int N, int C, int numFc1Out, half* output, const half* skip, const half* input, const half* w1, const half* b1, const half* w2, const half* b2, const half* bPrev) { // TODO: Think of more elegant way to avoid this hardcoding :-/ - if (numFc1Out == 32) { + if (numFc1Out == 16) { + if (C == 64) { + SE_Layer_NHWC<64, 16> + <<>>(output, skip, input, w1, b1, w2, b2, bPrev); + } else { + // TODO: support other channel counts. + throw Exception("channel count unsupported by SE layer"); + } + } else if (numFc1Out == 32) { if (C == 64) { SE_Layer_NHWC<64, 32> <<>>(output, skip, input, w1, b1, w2, b2, bPrev);