Shortcuts

Source code for mmeval.core.dist_backends.torch_cuda

# Copyright (c) OpenMMLab. All rights reserved.

from typing import TYPE_CHECKING, Any, Tuple, TypeVar, Union

from mmeval.utils import try_import
from .torch_cpu import TorchCPUDist

if TYPE_CHECKING:
    import torch
else:
    torch = try_import('torch')

Tensor = TypeVar('Tensor', bound='torch.Tensor')


[docs]class TorchCUDADist(TorchCPUDist): """A cuda distributed communication backend for torch.distributed.""" def __init__(self) -> None: super().__init__() if torch is None: raise ImportError(f'For availability of {self.__class__.__name__},' ' please install pytorch first.') if not torch.distributed.is_nccl_available(): raise RuntimeError( f'For availability of {self.__class__.__name__},' ' make sure torch.distributed.is_nccl_available().') def _object_to_tensor(self, obj: Any) -> Tuple[Tensor, Tensor]: """Convert the given object to a cuda tensor via `pickle.dumps`. Args: obj (any): Any pickle-able python object. Returns: tuple: A tuple of the tensor converted from given object and the tensor size. """ # Add type annotation make mypy happy obj_tensor: Tensor obj_size_tensor: Tensor obj_tensor, obj_size_tensor = super()._object_to_tensor(obj) return obj_tensor.cuda(), obj_size_tensor.cuda() def _tensor_to_object(self, tensor: Tensor, tensor_size: Union[int, Tensor]) -> Any: """Convert the given cuda tensor to a object via `pickle.loads`. Args: tenosr (Tensor): A cuda tensor. tensor_size (int or Tensor): The tensor size of the given Tensor to be convert object. Returns: Any: The object converted from the given cuda tensor. """ return super()._tensor_to_object(tensor.detach().cpu(), tensor_size)
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.