should we have an extension point for model transforms out of tree? #790

Open · vkuzo opened this issue Jan 15, 2025 · 4 comments
Labels: enhancement (New feature or request)
vkuzo (Contributor) commented Jan 15, 2025

In torchao, we have various low precision training features which are in prototype: MX, int8, bitnet. While we expect most of these to eventually end up in the main torchao APIs, it often takes ~months for a prototype to graduate.

torchtitan is extremely useful for helping us test low precision prototypes in real-world settings. For now, we've been creating unlanded PRs to test functionality (examples: #614, #778). Would torchtitan consider building an extension point to support this kind of experimentation fully out-of-tree?

An example of what this could look like:

  1. torchtitan provides a "model transformation" hook that it calls at a specified point in the initialization stage (for quantization, that should be after model init and before parallelization / torch.compile)
  2. user can provide a custom pass to transform the model (such as a prototype low precision training conversion pass)

I'm not entirely sure how this hook would be implemented, since the current interface of torchtitan is CLI based, but I wanted to share the request and start the discussion.
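To make the request more concrete, here is a minimal sketch of what such an extension point could look like (the names register_model_converter and apply_model_converters are hypothetical, not existing torchtitan APIs):

from typing import Callable, List

import torch.nn as nn

# Hypothetical sketch of an out-of-tree model transform hook; the names below
# are illustrative, not an existing torchtitan API.
_MODEL_CONVERTERS: List[Callable[[nn.Module], nn.Module]] = []

def register_model_converter(fn: Callable[[nn.Module], nn.Module]) -> None:
    # Called from user code, e.g. to register a prototype MX / int8 conversion pass.
    _MODEL_CONVERTERS.append(fn)

def apply_model_converters(model: nn.Module) -> nn.Module:
    # Called by the trainer after model init and before parallelization / torch.compile.
    for convert in _MODEL_CONVERTERS:
        model = convert(model) or model
    return model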

vkuzo (Contributor, Author) commented Jan 15, 2025

cc @awgu , @tianyu-l , @weifengpy

tianyu-l (Contributor) commented Jan 16, 2025

@vkuzo more than happy to work together on this!

Before we explore solutions, may I ask a few questions to better understand your request:

  • What would you like to achieve ideally? E.g., do you want to use torchtitan as a library without cloning it, or, more specifically, demonstrate your work in torchao by importing torchtitan as a library?
  • (maybe repeating the question above) What are the "pain points" you are experiencing with the current approaches (PR / branch / fork)?

tianyu-l added the enhancement label Jan 16, 2025
tianyu-l self-assigned this Jan 16, 2025
balancap commented Jan 17, 2025

Thanks @vkuzo for opening this issue, I actually wanted to raise the same question!

@tianyu-l We are very interested in having a similar feature. In our research team, some of our projects use TorchTitan as a git submodule instead of forking/copying the code (which makes upgrading to the latest main much cleaner and easier). For that use case, it is very useful to have a clean/simple entry point for modifying models, optimizers, ...

We've implemented a solution internally, and I would be very happy to open a PR for it. In broad strokes, we define a simple, general ModelHandler protocol with a registry:

from typing import List, Protocol, Union

import torch.nn as nn

# JobConfig / ParallelDims are torchtitan's job config and parallelism
# descriptors (exact import paths may vary by torchtitan version).
from torchtitan.config_manager import JobConfig
from torchtitan.parallelisms import ParallelDims


class ModelHandler(Protocol):
    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        ...

    def convert(self, model: nn.Module):
        # Transform the model (e.g. swap modules) after init, before parallelization.
        ...

    def pre_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        ...

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        ...


def register_model_handler(name: str, handler_cls: type) -> None:
    # Registry entry point (signature illustrative): maps a config key to a handler class.
    ...

This interface generalizes the Float8Handler: it can be used for any quantization handler, but also for model modifications (e.g. FlexAttention, custom fused kernels such as https://github.com/linkedin/Liger-Kernel/, ...).
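For illustration, a handler conforming to this protocol might look like the following (a hypothetical sketch; the body is a stand-in for a real conversion such as a fused-kernel or low precision swap):

# Hypothetical handler satisfying the ModelHandler protocol above.
class ExampleHandler:
    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        self.job_config = job_config  # handler-specific options would be read here

    def convert(self, model: nn.Module):
        # Swap or patch modules in place, before parallelization / torch.compile.
        pass

    def pre_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        pass  # e.g. sync scaling factors for a float8-style recipe

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        pass  # e.g. precompute scales for the next step

register_model_handler("ExampleHandler", ExampleHandler)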

Then in the TOML config, you can define the list of model handlers to apply, with each handler having its own (optional) parameters (passed to its __init__ via the JobConfig):

[model]
name = "llama3"
flavor = "3B"
handlers = ["Float8Handler", "FusedCrossEntropyLossHandler"]

[float8]
...

[fused_cross_entropy_loss]
...

The collection of model handlers is applied sequentially to the model, and similarly pre/post optimizer hooks are called one after the other.
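A rough sketch of how the trainer could drive the registered handlers (handler_registry here is an assumed name for the name-to-class mapping populated by register_model_handler; the other names are illustrative):

# Hypothetical driver code inside the trainer.
handlers = [
    handler_registry[name](job_config, parallel_dims)
    for name in job_config.model.handlers
]

# After model init, before parallelization / torch.compile:
for handler in handlers:
    handler.convert(model)

# Around optimizer.step() in the training loop:
for handler in handlers:
    handler.pre_optimizer_hook(model)
optimizer.step()
for handler in handlers:
    handler.post_optimizer_hook(model)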

Our feeling is that this strikes a good balance: it keeps the TorchTitan codebase simple while making it easier for users to incorporate their own training logic.

tianyu-l (Contributor) commented
@balancap
Thank you very much for the suggestions! Generalizing the Float8Handler sounds like a reasonable thing to do.

I have a question though:
I'm assuming convert would be applied before the parallelization & compilation code. That sounds viable as long as the underlying changes made by convert stay compatible with that code. But one might want to modify a (llama) model in a "parallelization-breaking" way, or register a completely new model (see #282). In such cases, do you think it's better to only support the forking approach, not the library/submodule approach?

Another general question:
Could you share, in your case, which parts of torchtitan you'd keep and which parts you'd modify?
E.g. keeping the model, data loading, and parallelization; modifying the dataset, optimizer / lr scheduler, FlexAttention / kernels.

cc: @vkuzo
