From 00aa086298bc10de9cc43a2261a0ca9868c482cb Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 12 Aug 2024 23:55:08 +0000 Subject: [PATCH] Revert "[dtensor] move tensor constructors to a separate module (#133129)" This reverts commit e890d888d916b4f38b383a59e0e9445513c67313. Reverted https://github.com/pytorch/pytorch/pull/133129 on behalf of https://github.com/fbgheith due to breaking internal tests ([comment](https://github.com/pytorch/pytorch/pull/133129#issuecomment-2285090400)) --- torch/distributed/_tensor/__init__.py | 355 +++++++++++++++++- .../_tensor/_tensor_constructors.py | 334 ---------------- torch/distributed/_tensor/debug/__init__.py | 3 +- torch/distributed/tensor/parallel/style.py | 2 +- 4 files changed, 340 insertions(+), 354 deletions(-) delete mode 100644 torch/distributed/_tensor/_tensor_constructors.py diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py index 0f863529d4f0d..85de716d8439e 100644 --- a/torch/distributed/_tensor/__init__.py +++ b/torch/distributed/_tensor/__init__.py @@ -1,19 +1,21 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates +from typing import Optional, Sequence +# Import all builtin dist tensor ops import torch -import torch.distributed._tensor.ops as _ops # force import all built-in dtensor ops -from torch.distributed._tensor._tensor_constructors import ( - empty, - full, - ones, - rand, - randn, - zeros, -) +import torch.distributed._tensor.ops +import torch.distributed._tensor.random as random +from torch.distributed._tensor._utils import compute_local_shape from torch.distributed._tensor.api import distribute_module, distribute_tensor, DTensor -from torch.distributed._tensor.placement_types import Partial, Replicate, Shard -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed._tensor.ops.utils import normalize_to_torch_size +from torch.distributed._tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh from torch.optim.optimizer import ( _foreach_supported_types as _optim_foreach_supported_types, ) @@ -32,12 +34,6 @@ "Shard", "Replicate", "Partial", - "ones", - "empty", - "full", - "rand", - "randn", - "zeros", ] @@ -48,3 +44,328 @@ if DTensor not in _util_foreach_supported_types: _util_foreach_supported_types.append(DTensor) + + +def _dtensor_init_helper( + init_op, + size: torch.Size, + device_mesh=None, + placements=None, + **kwargs, +) -> DTensor: + from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta + + # if device_mesh is None, use the one from mesh resources + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + kwargs["device"] = device_mesh.device_type + + # set default placements to replicated if not specified + placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) + + # check device_mesh againts placements + assert device_mesh.ndim == len( + placements + ), "mesh dimension does not match the length of placements" + + assert kwargs["layout"] == torch.strided, "layout value not supported!" 
+ torch_stride = torch._prims_common.make_contiguous_strides_for(size) + + # get local tensor shape + local_shape = compute_local_shape(size, device_mesh, placements) + # initialize the local tensor + if init_op == torch.full: + fill_value = kwargs.pop("fill_value", 0) + local_tensor = init_op(local_shape, fill_value, **kwargs) + elif init_op == torch.rand or init_op == torch.randn: + # this tensor meta is not used except `shape` + dtype = kwargs.get("dtype", torch.get_default_dtype()) + + tensor_meta = TensorMeta(size, (0,), dtype) + spec = DTensorSpec(device_mesh, placements, tensor_meta=tensor_meta) + + if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker: + random._rng_tracker = random.OffsetBasedRNGTracker() + + assert random._rng_tracker is not None + with random._rng_tracker._distribute_region(spec): + local_tensor = init_op(local_shape, **kwargs) + else: + local_tensor = init_op(local_shape, **kwargs) + + spec = DTensorSpec( + device_mesh, + tuple(placements), + tensor_meta=TensorMeta( + size, + torch_stride, + local_tensor.dtype, + ), + ) + + return DTensor( + local_tensor, + spec, + requires_grad=kwargs["requires_grad"], + ) + + +def ones( + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined + by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.ones, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def empty( + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor` + is defined by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. 
+ Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).\ + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.empty, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def full( + size, + fill_value, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with ``fill_value``. The scalar value type should match + ``device_mesh.device_type``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + fill_value(Scalar): the value to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.full, + torch_size, + fill_value=fill_value, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def rand( + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a uniform distribution + on the interval ``[0, 1)``. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. 
+ device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.rand, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def randn( + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a normal distribution + with mean 0 and variance 1. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.randn, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def zeros( + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 0. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..)) + Keyword args: + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. 
+ device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.zeros, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) diff --git a/torch/distributed/_tensor/_tensor_constructors.py b/torch/distributed/_tensor/_tensor_constructors.py deleted file mode 100644 index 03b8d88791926..0000000000000 --- a/torch/distributed/_tensor/_tensor_constructors.py +++ /dev/null @@ -1,334 +0,0 @@ -from typing import Optional, Sequence - -import torch -import torch.distributed._tensor.random as random -from torch.distributed._tensor._utils import compute_local_shape -from torch.distributed._tensor.api import DTensor -from torch.distributed._tensor.ops.utils import normalize_to_torch_size -from torch.distributed._tensor.placement_types import Placement, Replicate -from torch.distributed.device_mesh import _mesh_resources, DeviceMesh - - -def _dtensor_init_helper( # type: ignore[no-untyped-def] - init_op, - size: torch.Size, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, - **kwargs, -) -> DTensor: - from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta - - # if device_mesh is None, use the one from mesh resources - device_mesh = device_mesh or _mesh_resources.get_current_mesh() - kwargs["device"] = device_mesh.device_type - - # set default placements to replicated if not specified - placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) - - # check device_mesh againts placements - assert device_mesh.ndim == len( - placements - ), "mesh dimension does not match the length of placements" - - assert kwargs["layout"] == torch.strided, "layout value not supported!" 
- torch_stride = torch._prims_common.make_contiguous_strides_for(size) - - # get local tensor shape - local_shape = compute_local_shape(size, device_mesh, placements) - # initialize the local tensor - if init_op == torch.full: - fill_value = kwargs.pop("fill_value", 0) - local_tensor = init_op(local_shape, fill_value, **kwargs) - elif init_op == torch.rand or init_op == torch.randn: - # this tensor meta is not used except `shape` - dtype = kwargs.get("dtype", torch.get_default_dtype()) - - tensor_meta = TensorMeta(size, (0,), dtype) - spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta) - - if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker: - random._rng_tracker = random.OffsetBasedRNGTracker() - - assert random._rng_tracker is not None - with random._rng_tracker._distribute_region(spec): - local_tensor = init_op(local_shape, **kwargs) - else: - local_tensor = init_op(local_shape, **kwargs) - - spec = DTensorSpec( - device_mesh, - tuple(placements), - tensor_meta=TensorMeta( - size, - torch_stride, - local_tensor.dtype, - ), - ) - - return DTensor( - local_tensor, - spec, - requires_grad=kwargs["requires_grad"], - ) - - -def ones( # type: ignore[no-untyped-def] - *size, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - requires_grad: bool = False, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined - by the variable argument ``size``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.ones, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def empty( # type: ignore[no-untyped-def] - *size, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - requires_grad: bool = False, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor` - is defined by the variable argument ``size``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..)) - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. 
- Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).\ - layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.empty, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def full( # type: ignore[no-untyped-def] - size, - fill_value, - *, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - requires_grad: bool = False, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with ``fill_value``. The scalar value type should match - ``device_mesh.device_type``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - fill_value(Scalar): the value to fill the output tensor with. - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.full, - torch_size, - fill_value=fill_value, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def rand( # type: ignore[no-untyped-def] - *size, - requires_grad: bool = False, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with random numbers from a uniform distribution - on the interval ``[0, 1)``. The shape of the tensor is defined by the variable - argument ``size``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. 
- device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.rand, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def randn( # type: ignore[no-untyped-def] - *size, - requires_grad: bool = False, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with random numbers from a normal distribution - with mean 0 and variance 1. The shape of the tensor is defined by the variable - argument ``size``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.randn, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def zeros( # type: ignore[no-untyped-def] - *size, - requires_grad: bool = False, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with the scalar value 0. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..)) - Keyword args: - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. - Default: ``torch.strided``. 
- device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.zeros, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) diff --git a/torch/distributed/_tensor/debug/__init__.py b/torch/distributed/_tensor/debug/__init__.py index 7ad08c376028d..e18483ed30e55 100644 --- a/torch/distributed/_tensor/debug/__init__.py +++ b/torch/distributed/_tensor/debug/__init__.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.debug.comm_mode import CommDebugMode from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding @@ -12,8 +13,6 @@ def _get_sharding_prop_cache_info(): This would return a named tuple showing hits, misses, maxsize and cursize of the sharding propagator cache. """ - from torch.distributed._tensor.api import DTensor - return ( DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_info() # type:ignore[attr-defined] ) diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index 46b19a918296c..5c66584765316 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -11,10 +11,10 @@ distribute_module, distribute_tensor, DTensor, + Placement, Replicate, Shard, ) -from torch.distributed._tensor.placement_types import Placement __all__ = [
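
Note (not part of the patch): a minimal usage sketch of the constructors this revert moves back into torch.distributed._tensor.__init__. It assumes a 2-rank job launched with torchrun on a CUDA backend; the mesh shape and tensor sizes are illustrative only.

    import torch
    from torch.distributed._tensor import Shard, init_device_mesh, ones, zeros

    # Build a 1-D device mesh over 2 ranks (assumes torchrun --nproc-per-node=2).
    mesh = init_device_mesh("cuda", (2,))

    # With no placements given, _dtensor_init_helper defaults to Replicate() on
    # every mesh dimension, so each rank holds the full 4x4 tensor of ones.
    replicated = ones(4, 4, device_mesh=mesh)

    # Shard dim 0 across the mesh: each rank materializes a 2x4 local shard of zeros.
    sharded = zeros(4, 4, device_mesh=mesh, placements=[Shard(0)])

    print(replicated.placements, sharded.to_local().shape)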