should we have an extension point for model transforms out of tree? #790
In torchao, we have various low precision training features which are in prototype: MX, int8, bitnet. While we expect most of these to eventually end up in the main torchao APIs, it often takes ~months for a prototype to graduate.

torchtitan is extremely useful for helping us test low precision prototypes in real-world settings. For now, we've been creating unlanded PRs to test functionality (examples: #614, #778). Would torchtitan consider building an extension point to support this kind of experimentation fully out-of-tree?

An example of what this could look like:
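Below is a minimal sketch of the kind of hook being requested, assuming a hypothetical `register_model_transform` entry point; the name, its signature, and the `mx` example are illustrative assumptions, not an existing torchtitan API.

```python
# Sketch only: register_model_transform is an assumed extension point,
# not an existing torchtitan API; the stub registry makes it concrete.
from typing import Callable, Dict

import torch.nn as nn

_MODEL_TRANSFORMS: Dict[str, Callable[[nn.Module], nn.Module]] = {}

def register_model_transform(name: str, fn: Callable[[nn.Module], nn.Module]) -> None:
    # torchtitan could expose this; user code calls it at import time.
    _MODEL_TRANSFORMS[name] = fn

# --- out-of-tree user code, e.g. a torchao MX prototype ---
def convert_to_mx(model: nn.Module) -> nn.Module:
    # Swap or patch modules for the prototype dtype here (details elided).
    return model

register_model_transform("mx", convert_to_mx)
# torchtitan would then apply whichever registered transforms the job config selects.
```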
I'm not entirely sure how this hook would be implemented, since the current interface of torchtitan is CLI based, but I wanted to share the request and start the discussion.

cc @awgu, @tianyu-l, @weifengpy

---

@vkuzo more than happy to work together on this! Before we explore solutions, may I ask some questions to better understand your requests?
---
Thanks @vkuzo for opening this issue, I wanted to actually raise the same question! @tianyu-l We are very interested in having a similar feature. In our research team, some of our projects use TorchTitan as a git submodule instead of forking/copying the code (which makes upgrading to the latest version much easier). We've implemented a solution internally, and I would be very happy to open a PR for it. In broad strokes, we define a simple, general protocol:

```python
from typing import List, Protocol, Union

import torch.nn as nn

# JobConfig and ParallelDims are torchtitan's job-config and parallelism objects.

class ModelHandler(Protocol):
    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        ...

    def convert(self, model: nn.Module):
        ...

    def pre_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        ...

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        ...


def register_model_handler(...):
    pass
```

This interface generalizes the existing `Float8Handler`. Then in the TOML config, you can define the list of model handlers you want to apply, with every handler having its own (optional) parameters (passed to the handler's constructor):

```toml
[model]
name = "llama3"
flavor = "3B"
handlers = ["Float8Handler", "FusedCrossEntropyLossHandler"]
[float8]
...
[fused_cross_entropy_loss]
...
```

The collection of model handlers is applied to the model sequentially, and the pre/post optimizer hooks are likewise called one after the other. Our feeling is that this strikes a good balance: it keeps the TorchTitan codebase simple while making it easier for users to incorporate their own training logic.
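As a concrete illustration of the protocol above, here is a minimal sketch of one handler plus the registry and trainer glue. All of it is assumed for illustration: the `FusedCrossEntropyLossHandler` body, the registry, and the trainer call sites are hypothetical, not actual TorchTitan code.

```python
from typing import Dict, List, Type, Union

import torch.nn as nn

# Hypothetical registry backing register_model_handler() above.
_HANDLERS: Dict[str, Type] = {}

def register_model_handler(name: str, cls: Type) -> None:
    # Out-of-tree code calls this at import time to expose a handler.
    _HANDLERS[name] = cls

class FusedCrossEntropyLossHandler:
    """Illustrative handler; it would read its options from the
    [fused_cross_entropy_loss] section of the job config."""

    def __init__(self, job_config, parallel_dims):
        self.job_config = job_config
        self.parallel_dims = parallel_dims

    def convert(self, model: nn.Module):
        # Swap the output/loss computation for a fused implementation here.
        pass

    def pre_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        pass  # e.g. gradient bookkeeping before optimizer.step()

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        pass  # e.g. scale/state updates after optimizer.step()

register_model_handler("FusedCrossEntropyLossHandler", FusedCrossEntropyLossHandler)

# The trainer would instantiate the handlers named in [model].handlers, in order:
#   handlers = [_HANDLERS[n](job_config, parallel_dims) for n in job_config.model.handlers]
#   for h in handlers: h.convert(model)
#   ...each training step...
#   for h in handlers: h.pre_optimizer_hook(model)
#   optimizer.step()
#   for h in handlers: h.post_optimizer_hook(model)
```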
---

@balancap I have a question though: ... Another general question: ... cc: @vkuzo