
BUG: early_step_in_backward with pipeline parallelism and len(model_parts) > 1 #777

Open
cassanof opened this issue Jan 7, 2025 · 2 comments
Labels: bug
cassanof commented Jan 7, 2025

The `__init__` method of `OptimizersInBackwardContainer` has a bug:

```python
def optim_hook(param) -> None:
    optim_dict[param].step()
    optim_dict[param].zero_grad()

The hook closure tries to capture `optim_dict`, but if `model_parts` has more than one entry, which can be the case with pipeline parallelism, every hook ends up capturing the last `optim_dict`. This throws an error during backward, because parameters from the first model part are not contained in that dict.
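A hypothetical sketch of the pattern (names and structure are assumptions for illustration, not torchtitan's actual code) shows how hooks registered inside a loop over `model_parts` all share one `optim_dict`:

```python
# Hypothetical sketch: hooks registered in a loop over model_parts all
# close over the *same* local name `optim_dict`, so after the loop every
# hook resolves it to the last part's dict.
def register_hooks(model_parts, make_optim):
    hooks = []
    for part in model_parts:
        optim_dict = {p: make_optim(p) for p in part}

        def optim_hook(param) -> None:
            # `optim_dict` is looked up at call time, not at
            # registration time -> always the last part's dict.
            optim_dict[param].step()
            optim_dict[param].zero_grad()

        for p in part:
            hooks.append((p, optim_hook))
    return hooks
```

With two model parts, invoking the hook registered for a parameter of the first part raises a `KeyError`, because by then `optim_dict` refers to the second part's dict.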

Also, the fused backward+optim code doesn't seem to handle gradient clipping.


mori360 commented Jan 7, 2025

> throwing an error on backward as the parameters in the first model part will not be contained in this dict.

I tried running with `--experimental.pipeline_parallel_degree 2 --optimizer.early_step_in_backward`, and all optimizers call `.step()` with the changes.
Could you give an example to repro the issue? Thank you!

> the fused backward+optim code doesn't seem to handle gradient clipping.

backward+optim frees each gradient during backward (to reduce memory cost), so there are no gradients left for gradient clipping.


cassanof commented Jan 7, 2025

Hi! With your configuration, will `model_parts` have `len > 1`? If not, it will work correctly; otherwise it should break.
This is because Python closures capture variables by reference to the identifier, not by value. Whenever `optim_hook` is invoked, it uses the final state of `optim_dict` after the loop completes, rather than the state at the time the hook was registered.
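This late-binding behavior is easy to demonstrate in isolation, along with the usual fix of snapshotting the value through a default argument (a minimal sketch, not torchtitan code):

```python
# Minimal repro of Python's late-binding closures.
def make_hooks(dicts):
    hooks = []
    for d in dicts:
        # BUG: `d` is resolved when the hook is *called*, so every
        # hook ends up seeing the last dict of the loop.
        hooks.append(lambda key: d[key])
    return hooks

def make_hooks_fixed(dicts):
    hooks = []
    for d in dicts:
        # FIX: `d=d` binds the current dict at definition time.
        hooks.append(lambda key, d=d: d[key])
    return hooks
```

With `make_hooks([{"a": 1}, {"b": 2}])`, the first hook raises `KeyError` for `"a"` because it sees only `{"b": 2}`; the fixed variant returns `1` as expected.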

I have trouble producing a repro because my version of torchtitan is heavily modified.

> backward+optim would free gradient during backward(to optimize memory cost) so that there's no gradient for gradient clipping

That makes sense. Wouldn't you still want to clip the gradient before stepping the optimizer? I understand you can't do grad-norm clipping, since you can't compute the global norm, but you could use traditional value clipping, where each gradient is clamped to some threshold.
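Value clipping (what `torch.nn.utils.clip_grad_value_` does in PyTorch) only clamps each gradient element to a range, so it needs no global norm and could in principle run inside a per-parameter hook before the step. A plain-Python sketch of the operation:

```python
# Plain-Python sketch of value clipping: clamp every gradient element
# to [-threshold, threshold]. Unlike norm clipping, this is purely
# element-wise, so it could run per parameter before optimizer.step().
def clip_grad_value(grad, threshold):
    return [max(-threshold, min(threshold, g)) for g in grad]
```

For example, with `threshold=1.0`, a gradient of `-5.0` becomes `-1.0` while `0.5` is left untouched.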
