diff --git a/src/transformers/commands/pt_to_tf.py b/src/transformers/commands/pt_to_tf.py deleted file mode 100644 index ad0dbd14e15b56..00000000000000 --- a/src/transformers/commands/pt_to_tf.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from argparse import ArgumentParser, Namespace - -from ..utils import logging -from . import BaseTransformersCLICommand - - -MAX_ERROR = 5e-5 # larger error tolerance than in our internal tests, to avoid flaky user-facing errors - - -def convert_command_factory(args: Namespace): - """ - Factory function used to convert a model PyTorch checkpoint in a TensorFlow 2 checkpoint. - - Returns: ServeCommand - """ - return PTtoTFCommand( - args.model_name, - args.local_dir, - args.max_error, - args.new_weights, - args.no_pr, - args.push, - args.extra_commit_description, - args.override_model_class, - ) - - -class PTtoTFCommand(BaseTransformersCLICommand): - @staticmethod - def register_subcommand(parser: ArgumentParser): - """ - Register this command to argparse so it's available for the transformer-cli - - Args: - parser: Root parser to register command-specific arguments - """ - train_parser = parser.add_parser( - "pt-to-tf", - help=( - "CLI tool to run convert a transformers model from a PyTorch checkpoint to a TensorFlow checkpoint." - " Can also be used to validate existing weights without opening PRs, with --no-pr." - ), - ) - train_parser.add_argument( - "--model-name", - type=str, - required=True, - help="The model name, including owner/organization, as seen on the hub.", - ) - train_parser.add_argument( - "--local-dir", - type=str, - default="", - help="Optional local directory of the model repository. Defaults to /tmp/{model_name}", - ) - train_parser.add_argument( - "--max-error", - type=float, - default=MAX_ERROR, - help=( - f"Maximum error tolerance. Defaults to {MAX_ERROR}. This flag should be avoided, use at your own risk." - ), - ) - train_parser.add_argument( - "--new-weights", - action="store_true", - help="Optional flag to create new TensorFlow weights, even if they already exist.", - ) - train_parser.add_argument( - "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights." - ) - train_parser.add_argument( - "--push", - action="store_true", - help="Optional flag to push the weights directly to `main` (requires permissions)", - ) - train_parser.add_argument( - "--extra-commit-description", - type=str, - default="", - help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).", - ) - train_parser.add_argument( - "--override-model-class", - type=str, - default=None, - help="If you think you know better than the auto-detector, you can specify the model class here. " - "Can be either an AutoModel class or a specific model class like BertForSequenceClassification.", - ) - train_parser.set_defaults(func=convert_command_factory) - - def __init__( - self, - model_name: str, - local_dir: str, - max_error: float, - new_weights: bool, - no_pr: bool, - push: bool, - extra_commit_description: str, - override_model_class: str, - *args, - ): - self._logger = logging.get_logger("transformers-cli/pt_to_tf") - self._model_name = model_name - self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name) - self._max_error = max_error - self._new_weights = new_weights - self._no_pr = no_pr - self._push = push - self._extra_commit_description = extra_commit_description - self._override_model_class = override_model_class - - def run(self): - # TODO (joao): delete file in v4.47 - raise NotImplementedError( - "\n\nConverting PyTorch weights to TensorFlow weights was removed in v4.43. " - "Instead, we recommend that you convert PyTorch weights to Safetensors, an improved " - "format that can be loaded by any framework, including TensorFlow. For more information, " - "please see the Safetensors conversion guide: " - "https://huggingface.co/docs/safetensors/en/convert-weights\n\n" - ) diff --git a/src/transformers/commands/transformers_cli.py b/src/transformers/commands/transformers_cli.py index 6e8cfea0c3141a..8e25eb6e13fe87 100644 --- a/src/transformers/commands/transformers_cli.py +++ b/src/transformers/commands/transformers_cli.py @@ -20,7 +20,6 @@ from .download import DownloadCommand from .env import EnvironmentCommand from .lfs import LfsCommands -from .pt_to_tf import PTtoTFCommand from .run import RunCommand from .serving import ServeCommand from .user import UserCommands @@ -39,7 +38,6 @@ def main(): UserCommands.register_subcommand(commands_parser) AddNewModelLikeCommand.register_subcommand(commands_parser) LfsCommands.register_subcommand(commands_parser) - PTtoTFCommand.register_subcommand(commands_parser) # Let's go args = parser.parse_args() diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index 232eed95b9dd04..0d05b105a4dc57 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -351,7 +351,6 @@ src/transformers/commands/convert.py src/transformers/commands/download.py src/transformers/commands/env.py src/transformers/commands/lfs.py -src/transformers/commands/pt_to_tf.py src/transformers/commands/run.py src/transformers/commands/serving.py src/transformers/commands/train.py