
Does the synchronized version of abn support fp16 training? #150

Closed
bwang-delft opened this issue Dec 2, 2019 · 8 comments

@bwang-delft
The regular ABN supports fp16 training, but I couldn't make sync_abn work with fp16 training. Did I do something wrong?

@ducksoup
Contributor

ducksoup commented Dec 5, 2019

@bwang-delft can you please provide more context about the way you are using fp16 with sync_abn and the errors you are encountering?

@bwang-delft
Author

(1) Replace every ABN layer with an InPlaceABNSync layer in my network.
(2) Change the activation from 'relu' to 'leaky_relu'.
(3) Cast my model to half().
(4) Use 2 GPUs, with batch size set to 1 on each.
(5) Start my training script with the following command: python3 -m torch.distributed.launch --nproc_per_node=2 train.py
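For reference, steps (1)–(3) above can be sketched with a generic module-swapping helper. This is an illustrative sketch, not code from this repo: the helper name and the exact InPlaceABNSync constructor arguments are assumptions.

```python
import torch.nn as nn

def swap_layers(module, old_cls, make_new):
    """Recursively replace every instance of old_cls in-place,
    building each replacement from the old layer via make_new."""
    for name, child in module.named_children():
        if isinstance(child, old_cls):
            setattr(module, name, make_new(child))
        else:
            swap_layers(child, old_cls, make_new)
    return module

# Hypothetical usage mirroring steps (1)-(3); assumes inplace_abn is installed
# and that ABN exposes num_features/eps/momentum/affine attributes:
# from inplace_abn import ABN, InPlaceABNSync
# swap_layers(model, ABN, lambda abn: InPlaceABNSync(
#     abn.num_features, eps=abn.eps, momentum=abn.momentum,
#     affine=abn.affine, activation="leaky_relu"))
# model = model.half()
```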

Training started successfully but crashed after the first epoch, with the following messages:

RuntimeError: mean is not compatible with x (wrong size or scalar type) (forward at src/inplace_abn.cpp:52)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x33 (0x7f1fab3cb813 in /home/beinan/.local/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: forward(at::Tensor&, at::Tensor const&, at::Tensor const&, c10::optionalat::Tensor const&, c10::optionalat::Tensor const&, float, Activation, float) + 0x42b (0x7f1f837c886b in /home/beinan/.local/lib/python3.6/site-packages/inplace_abn/_backend.cpython-36m-x86_64-linux-gnu.so)
frame #2: + 0x85619 (0x7f1f837e7619 in /home/beinan/.local/lib/python3.6/site-packages/inplace_abn/_backend.cpython-36m-x86_64-linux-gnu.so)
frame #3: + 0x8126a (0x7f1f837e326a in /home/beinan/.local/lib/python3.6/site-packages/inplace_abn/_backend.cpython-36m-x86_64-linux-gnu.so)
frame #4: /usr/bin/python3() [0x4f8925]
frame #5: _PyEval_EvalFrameDefault + 0x467 (0x4f98c7 in /usr/bin/python3)
frame #6: /usr/bin/python3() [0x4f6128]
frame #7: /usr/bin/python3() [0x56febd]
frame #8: PyObject_Call + 0x3e (0x57c2fe in /usr/bin/python3)
frame #9: THPFunction_apply(_object*, _object*) + 0xa4f (0x7f1ff3ce75ef in /home/beinan/.local/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
[frames #10–#63 are Python interpreter internals; omitted for brevity]

_backend.forward(x, mean, var, weight, bias, ctx.eps, ctx.activation, ctx.activation_param)

RuntimeError: mean is not compatible with x (wrong size or scalar type) (forward at src/inplace_abn.cpp:52)
[identical stack trace printed by the second worker process; omitted]

@bwang-delft
Author

I have another question. Why does ABN support relu activation but InPlaceABNSync does not?

@bwang-delft
Author

It seems that the first validation-loss computation caused the crash. Training runs fine, but validation somehow triggers the error.

@bwang-delft
Author

OK, the cause is model.eval(). But why?

@ducksoup
Contributor

> I have another question. Why does ABN support relu activation but InPlaceABNSync does not?

InPlaceABN and InPlaceABNSync require invertible activation functions such as Leaky ReLU (please refer to our paper for an explanation). ABN is just a wrapper around standard (not in-place) Batch Norm plus an activation function, so it can also work with ReLU, which is not invertible.
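For intuition on why invertibility matters: the in-place variants discard the pre-activation values and reconstruct them during the backward pass by inverting the activation. Leaky ReLU has an exact inverse for any nonzero negative slope, while ReLU collapses all non-positive inputs to zero. A minimal scalar illustration (not the library's actual kernel code):

```python
def leaky_relu(x, slope=0.01):
    # y = x for x > 0, slope * x otherwise
    return x if x > 0 else slope * x

def leaky_relu_inv(y, slope=0.01):
    # Exact inverse whenever slope != 0: the negative branch is y / slope.
    return y if y > 0 else y / slope

# ReLU (slope == 0) maps every x <= 0 to 0, so the input cannot be
# recovered from the output alone -- hence ReLU is not invertible.
```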

Regarding the error you are getting in eval: it is caused by the fact that the current fp16 implementation of InPlaceABN still expects the running mean and running variance to be given as fp32 tensors, as is done in mixed-precision training. Pure fp16 support will come in the future, but I can't give you a precise timeline. In the meantime, as a workaround, please run the following code on your model after calling half() on it:

```python
from inplace_abn import ABN  # base class of the in-place variants

def cast_running_stats(m):
    # Keep running statistics in fp32 even though the rest of the
    # model has been cast to fp16 with half().
    if isinstance(m, ABN):
        m.running_mean = m.running_mean.float()
        m.running_var = m.running_var.float()

model.apply(cast_running_stats)
```
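This workaround works because running_mean and running_var are registered buffers, so half() casts them to fp16 along with the parameters; reassigning them as .float() tensors restores the fp32 dtype the kernel expects. The same buffer mechanics can be seen with a plain BatchNorm2d, used here only as a stand-in for ABN:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8).half()  # half() casts parameters AND buffers to fp16
assert bn.running_mean.dtype == torch.float16

# Assigning a tensor to a registered buffer name updates the module's
# buffer dict, mirroring what cast_running_stats does for ABN layers:
bn.running_mean = bn.running_mean.float()
bn.running_var = bn.running_var.float()
assert bn.running_mean.dtype == torch.float32
```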

@bwang-delft
Author

Thanks, your answer is very helpful.

@dongdong93

> In the meanwhile, as a work-around, please run the following code on your model after calling half() on it: [...]

What about using apex? Should we still apply cast_running_stats?
