Saving should be aware of the dp mesh #799

Open
carmocca opened this issue Jan 21, 2025 · 0 comments

Comments

Contributor

carmocca commented Jan 21, 2025

The docs for dcp.save say:

When saving checkpoint for FSDP’s ShardingStrategy.HYBRID_SHARD, only one of the shard_group should be calling save_state_dict and the corresponding process group needs to be passed in.

This references the old FSDP1 sharding strategy; however, it should apply equally to FSDP2. I believe it also applies to the different async-saving flavors that are implemented.

Today, torchtitan does not do this:

```python
elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
    self._async_with_pinned_memory(checkpoint_id)
elif self.async_mode == AsyncMode.ASYNC:
    self.async_future = dcp.async_save(
        self.states, checkpoint_id=checkpoint_id, process_group=self.pg
    )
else:
    dcp.save(self.states, checkpoint_id=checkpoint_id)
```

Do you agree that it should? Have you tried this with FSDP2? Do you know of any blockers?
This would decrease the saving burden considerably for jobs with a large data_parallel_replicate_degree.
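To illustrate the potential saving, here is a minimal sketch of which ranks would need to participate in the save, assuming a row-major 2D device mesh with the replicate dimension outermost (as torchtitan builds it). The function name `saving_ranks` is hypothetical, for illustration only; only one replica group's ranks hold a distinct copy of each shard, so only they need to write:

```python
# Sketch (assumption): ranks form a (replicate, shard) mesh laid out
# row-major, replicate dimension outermost. Replicas hold identical
# shards, so only one replica group needs to call dcp.save.

def saving_ranks(world_size: int, replicate_degree: int) -> list[int]:
    """Ranks of replica group 0, i.e. the only ranks that would need to
    participate in the checkpoint write under this layout assumption."""
    assert world_size % replicate_degree == 0
    shard_degree = world_size // replicate_degree
    # With the replicate dimension outermost, replica group 0 owns the
    # contiguous ranks [0, shard_degree).
    return list(range(shard_degree))

# Example: 8 GPUs with data_parallel_replicate_degree=2 gives two
# replicas of a 4-way shard; only ranks 0-3 need to write.
print(saving_ranks(8, 2))  # -> [0, 1, 2, 3]
```

With replicate degree 2, half the ranks (and half the write bandwidth) could be left out of the save entirely.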

Thank you!
