Does the synchronized version of abn support fp16 training? #150
The regular ABN supports fp16 training, but I couldn't make sync_abn work with fp16 training. Did I do something wrong?
Comments
@bwang-delft, can you please provide more context about how you are using fp16 with sync_abn and the errors you are encountering?
(1) I replaced every ABN layer with an InPlaceABNSync layer in my network. The training process was successful, but it failed after the first epoch with the following message:
RuntimeError: mean is not compatible with x (wrong size or scalar type) (forward at src/inplace_abn.cpp:52)
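For reference, a minimal sketch of the substitution described in (1). The import path and constructor signatures are assumptions about the inplace_abn package and may need adjusting to the installed version; conv_bn_act is a hypothetical helper.

import torch.nn as nn
from inplace_abn import ABN, InPlaceABNSync  # assumed import path

def conv_bn_act(in_ch, out_ch, sync=True):
    # Swap the ABN wrapper for its synchronized, in-place counterpart.
    norm_act = InPlaceABNSync(out_ch) if sync else ABN(out_ch)
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        norm_act,
    )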
I have another question: why does ABN support ReLU activation, but InPlaceABNSync does not?
It seems that the first validation loss calculation caused the crash. Training is fine, but somehow validation causes an error.
OK. The cause is model.eval(), but why?
InPlaceABN and InPlaceABNSync require invertible activation functions such as Leaky ReLU (please refer to our paper for an explanation). ABN is just a wrapper for standard Batch Norm (not in-place) + activation function, so it can also work with ReLU, which is not invertible.
Regarding the error you are getting in eval: it is caused by the fact that the current fp16 implementation of InPlaceABN still expects the running mean and running var to be given as fp32 tensors, as is done in mixed-precision training. Pure fp16 support will come in the future, but I can't give you a precise timeline. In the meanwhile, as a work-around, please run the following code on your model after converting it to fp16:

def cast_running_stats(m):
    if isinstance(m, ABN):
        m.running_mean = m.running_mean.float()
        m.running_var = m.running_var.float()

model.apply(cast_running_stats)
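A minimal usage sketch of this work-around, assuming the model was converted to fp16 with model.half(); MyNetwork and inputs are placeholders.

import torch

model = MyNetwork().cuda()           # MyNetwork is a hypothetical network built with ABN layers
model.half()                         # casts parameters and buffers, including running stats, to fp16
model.apply(cast_running_stats)      # cast running_mean / running_var back to fp32 as shown above

model.eval()                         # eval mode is where the original error appeared
with torch.no_grad():
    output = model(inputs.half().cuda())  # inputs is a placeholder batch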
Thanks, your answer is very helpful.
How about using apex? Should we still apply cast_running_stats?
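The apex question is not answered in the thread. Below is a rough sketch, assuming apex's amp with opt_level O2; whether keep_batchnorm_fp32 recognizes the custom ABN / InPlaceABNSync classes is an assumption, so re-applying cast_running_stats is shown as a conservative extra step. MyNetwork, criterion, and loader are placeholders.

import torch
from apex import amp

model = MyNetwork().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# O2 keeps standard batch-norm layers in fp32, but it may not recognize the
# custom ABN classes, hence the explicit cast_running_stats call below.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2",
                                  keep_batchnorm_fp32=True)
model.apply(cast_running_stats)      # no-op if the running stats are already fp32

for inputs, targets in loader:
    loss = criterion(model(inputs.cuda()), targets.cuda())
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    optimizer.zero_grad()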