-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuffer.py
More file actions
86 lines (67 loc) · 3.05 KB
/
buffer.py
File metadata and controls
86 lines (67 loc) · 3.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import math
from typing import List, Optional
import torch
import torch.distributed as dist
class Buffer:
    """Contiguous storage with sharding and collective communication.

    Flattens a list of tensors into a single contiguous 1-D buffer, padded
    to be evenly divisible by world_size. Maintains per-tensor views into
    the buffer and a permanent local shard for reduce-scatter / all-gather.

    This class is agnostic to what the tensors represent (parameters,
    gradients, or anything else). Higher layers decide how to use it.

    Args:
        tensors: Tensors whose shapes define the layout. Only shapes and
            numels are read; data is not copied.
        rank: Local rank in the distributed group. Must satisfy
            ``0 <= rank < world_size``.
        world_size: Total number of ranks. Must be at least 1.
        dtype: Element type for the buffer.
        device: Target CUDA device. Defaults to cuda:<rank>.

    Attributes:
        rank: The rank passed to the constructor.
        world_size: The world size passed to the constructor.
        dtype: Element type of the buffer.
        device: Device the buffer lives on.
        padded_numel: Total buffer length, including trailing padding.
        shard_numel: Number of elements in each rank's shard
            (``padded_numel // world_size``).

    Raises:
        ValueError: If ``world_size`` is smaller than 1 or ``rank`` is
            outside ``[0, world_size)``.
    """

    def __init__(
        self,
        tensors: List[torch.Tensor],
        rank: int,
        world_size: int,
        dtype: torch.dtype = torch.bfloat16,
        device: Optional[torch.device] = None,
    ):
        # Fail early with a clear message: a bad world_size would otherwise
        # surface as a ZeroDivisionError, and an out-of-range rank would
        # silently produce an empty or misplaced shard slice.
        if world_size < 1:
            raise ValueError(
                f"world_size must be a positive integer, got {world_size}"
            )
        if not 0 <= rank < world_size:
            raise ValueError(
                f"rank must be in [0, {world_size}), got {rank}"
            )

        self.rank = rank
        self.world_size = world_size
        self.dtype = dtype
        self.device = device if device is not None else torch.device("cuda", rank)

        numels = [t.numel() for t in tensors]
        total_numel = sum(numels)

        # Integer ceiling division: stays exact for any element count,
        # whereas float division loses precision above 2**53.
        self.shard_numel = -(-total_numel // world_size)
        self.padded_numel = self.shard_numel * world_size
        # Number of trailing zero elements not covered by any view.
        self._padding = self.padded_numel - total_numel

        self._data = torch.zeros(self.padded_numel, dtype=dtype, device=self.device)

        # The local shard is a slice (not a copy) of the full buffer, so
        # writes through either one are visible through the other.
        shard_start = rank * self.shard_numel
        shard_end = shard_start + self.shard_numel
        self._local_shard = self._data[shard_start:shard_end]

        # One view per input tensor, laid out back to back in input order.
        self._views: List[torch.Tensor] = []
        offset = 0
        for t, numel in zip(tensors, numels):
            view = self._data[offset : offset + numel].view(t.shape)
            self._views.append(view)
            offset += numel

    def get_views(self) -> List[torch.Tensor]:
        """Return pre-created views into the full buffer, one per input tensor."""
        return self._views

    def get_local_shard(self) -> torch.Tensor:
        """Return the permanent local shard tensor (shard_numel elements)."""
        return self._local_shard

    def get_data(self) -> torch.Tensor:
        """Return the full contiguous 1-D buffer."""
        return self._data

    def reduce_scatter(self, group: Optional[dist.ProcessGroup] = None) -> None:
        """Reduce-scatter the full buffer (SUM) into the local shard.

        The output tensor is the local shard, which is itself a slice of the
        input buffer, so the result is written in place.

        Args:
            group: Process group for the collective. None uses the default.
        """
        dist.reduce_scatter_tensor(self._local_shard, self._data, group=group)

    def all_gather(self, group: Optional[dist.ProcessGroup] = None) -> None:
        """All-gather local shards from every rank into the full buffer.

        The input tensor is the local shard, which is itself a slice of the
        output buffer.

        Args:
            group: Process group for the collective. None uses the default.
        """
        dist.all_gather_into_tensor(self._data, self._local_shard, group=group)