The init method for the OptimizersInBackwardContainer has a bug (torchtitan/torchtitan/optimizer.py, lines 99 to 101 at commit 2a44370): the hook closure tries to capture optim_dict, but if there is more than one entry in model_parts, which can be the case with pipeline parallelism, every hook ends up capturing the last optim_dict. This throws an error on backward, because the parameters of the first model part are not contained in that dict.

Also, the fused backward+optim code doesn't seem to handle gradient clipping.
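To make the first point concrete, here is a rough sketch of the pattern at the linked lines. The names model_parts, optimizer_kwargs, optim_dict, and optim_hook are my approximations, and the details may not match the file exactly:

```python
import torch

# Rough sketch of the registration loop (names approximate, not the exact code):
# one optim_dict is built per model part, and a hook closing over it is
# registered on every parameter of that part.
for model in model_parts:
    optim_dict = {
        param: torch.optim.AdamW([param], **optimizer_kwargs)
        for param in model.parameters()
    }

    def optim_hook(param) -> None:
        # optim_dict is looked up when the hook fires, so every hook ends up
        # using the dict built in the *last* loop iteration.
        optim_dict[param].step()
        optim_dict[param].zero_grad()

    for param in model.parameters():
        param.register_post_accumulate_grad_hook(optim_hook)
```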
> throwing an error on backward as the parameters in the first model part will not be contained in this dict.

I tried running with "--experimental.pipeline_parallel_degree 2" and "--optimizer.early_step_in_backward"; all optimizers call .step() and apply changes. Could you give an example to repro the issue? Thank you!
> the fused backward+optim code doesn't seem to handle gradient clipping.

backward+optim frees each gradient during backward (to optimize memory cost), so there is no gradient left for gradient clipping.
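As a standalone illustration, using register_post_accumulate_grad_hook as a stand-in for the fused backward+optim path (sketch only):

```python
import torch

# A per-parameter optimizer that steps and frees the grad inside backward,
# mimicking the fused backward+optim behaviour.
p = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.SGD([p], lr=0.1)

def hook(param):
    opt.step()
    opt.zero_grad()  # set_to_none=True by default, so the grad is freed here

p.register_post_accumulate_grad_hook(hook)

(p * 2.0).sum().backward()
print(p.grad)  # None -- by the time backward() returns, there is nothing left to clip
```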
Hi! For your configuration, will that create a model_parts with len > 1? If not, then it will work correctly; otherwise it should break.

This is because Python closures capture by reference to the identifier, not by value: when optim_hook is invoked, it always uses the final state of optim_dict after the loop has completed, rather than the state at the time the hook was registered.
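A minimal illustration of the capture behavior, independent of torchtitan:

```python
hooks = []
for d in [{"part": 0}, {"part": 1}]:
    def hook():
        return d["part"]          # 'd' is resolved when hook() runs
    hooks.append(hook)

print([h() for h in hooks])       # [1, 1] -- both closures see the last dict

# Binding the current value as a default argument fixes it:
hooks = []
for d in [{"part": 0}, {"part": 1}]:
    def hook(d=d):                # the default is evaluated at definition time
        return d["part"]
    hooks.append(hook)

print([h() for h in hooks])       # [0, 1]
```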
I have trouble producing a repro because my version of torchtitan is heavily modified.
> backward+optim frees each gradient during backward (to optimize memory cost), so there is no gradient left for gradient clipping.

That makes sense. Wouldn't you still want to clip the grad before stepping the optimizer, though? I understand you can't do grad-norm clipping, because you can't compute the global norm, but you could use traditional value clipping, where you clamp each grad to some threshold.
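Something along these lines could work, just as a sketch; make_optim_hook and clip_value are hypothetical names, not existing torchtitan options:

```python
def make_optim_hook(optim_dict, clip_value):
    # Value clipping needs only the parameter's own grad, so it can run
    # inside the per-parameter hook right before the fused optimizer step.
    def optim_hook(param) -> None:
        if param.grad is not None:
            param.grad.clamp_(-clip_value, clip_value)
        optim_dict[param].step()
        optim_dict[param].zero_grad()
    return optim_hook
```

Passing optim_dict in explicitly, rather than closing over the loop variable, would also sidestep the capture problem from the original report.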