Sifal committed
Commit 8f74c7f · verified · 1 Parent(s): f3edc31

Create bert_padding.py

Files changed (1)
  1. bert_padding.py +159 -0
bert_padding.py ADDED
@@ -0,0 +1,159 @@
+# Copyright 2022 MosaicML Examples authors
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
+# Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
+
+"""Helper functions for padding and unpadding batches.
+
+These functions are used extensively throughout the Mosaic BERT implementation
+in `bert_layers.py`.
+"""
+
+from typing import Tuple, cast
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+
+class IndexFirstAxis(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, input: torch.Tensor,
+                indices: torch.Tensor) -> torch.Tensor:
+        """Get just the values of `input` which are at `indices`.
+
+        Arguments:
+            ctx: the autograd context object
+            input: (b, ...) 2+ dimensional tensor
+            indices: (num_idx) 1D tensor
+        """
+        ctx.save_for_backward(indices)
+        assert input.ndim >= 2
+        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[
+            1:]  # type: ignore
+        second_dim = other_shape.numel(
+        )  # product of sizes of all but first dimension
+        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
+        return torch.gather(
+            rearrange(input, 'b ... -> b (...)'),  # (b, ...) -> (b, second_dim)
+            0,
+            repeat(indices, 'z -> z d',
+                   d=second_dim)  # (indices,) -> (indices, second_dim)
+        ).reshape(-1, *other_shape)  # (num_idx, ...)
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
+        indices, = ctx.saved_tensors
+        assert grad_output.ndim >= 2
+        other_shape = grad_output.shape[1:]
+        grad_output = rearrange(grad_output, 'b ... -> b (...)')
+        grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
+                                 device=grad_output.device,
+                                 dtype=grad_output.dtype)
+        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
+        # grad_input[indices] = grad_output
+        grad_input.scatter_(0,
+                            repeat(indices, 'z -> z d', d=grad_output.shape[1]),
+                            grad_output)
+        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
+
+
+index_first_axis = IndexFirstAxis.apply
+
+
+class IndexPutFirstAxis(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, values: torch.Tensor, indices: torch.Tensor,
+                first_axis_dim) -> torch.Tensor:
+        ctx.save_for_backward(indices)
+        assert indices.ndim == 1
+        assert values.ndim >= 2
+        output = torch.zeros(first_axis_dim,
+                             *values.shape[1:],
+                             device=values.device,
+                             dtype=values.dtype)
+        output[indices] = values
+        return output
+
+    @staticmethod
+    def backward(ctx,
+                 grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
+        indices, = ctx.saved_tensors
+        grad_values = grad_output[indices]
+        return grad_values, None, None
+
+
+index_put_first_axis = IndexPutFirstAxis.apply
+
+
+def unpad_input(
+    hidden_states: torch.Tensor,
+    attention_mask: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
+    """Remove padding from input sequences.
+
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+
+    Returns:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz)
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+    """
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
+                       (1, 0))
+    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
+    # bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
+    # times larger than they need to be, wasting memory. It's faster and more memory-efficient to
+    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
+    # so we write custom forward and backward to make it a bit faster.
+    hidden_states = cast(
+        torch.Tensor,
+        index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'),
+                         indices))
+    return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
+
+
+def unpad_input_only(
+    hidden_states: torch.Tensor,
+    attention_mask: torch.Tensor,
+) -> torch.Tensor:
+    """Like unpad_input, but only return the unpadded first tensor.
+
+    Saves a small amount of overhead.
+
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+
+    Returns:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+    """
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    rearranged = rearrange(hidden_states, 'b s ... -> (b s) ...')
+    return index_first_axis(rearranged, indices)  # type: ignore
+
+
+def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int,
+              seqlen: int) -> torch.Tensor:
+    """Add padding to sequences.
+
+    Arguments:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz)
+        batch: int batch_size
+        seqlen: int max sequence length
+
+    Returns:
+        hidden_states: (batch, seqlen, ...)
+    """
+    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
+    return rearrange(output, '(b s) ... -> b s ...', b=batch)  # type: ignore
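
For context, a minimal usage sketch (not part of the committed file) of how unpad_input and pad_input are typically paired: the batch is packed into a (total_nnz, hidden) tensor for token-level compute and then restored to the padded (batch, seqlen, hidden) layout. The tensor sizes and the `bert_padding` import path are illustrative assumptions.

# Illustrative usage sketch; assumes this file is importable as `bert_padding`.
import torch

from bert_padding import pad_input, unpad_input

batch, seqlen, hidden = 2, 4, 8  # small made-up sizes
hidden_states = torch.randn(batch, seqlen, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # 1 = valid token, 0 = padding

# Drop padded positions: (batch, seqlen, hidden) -> (total_nnz, hidden)
unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
assert unpadded.shape == (5, hidden)      # 3 + 2 valid tokens
assert cu_seqlens.tolist() == [0, 3, 5]   # per-sequence offsets into `unpadded`
assert max_seqlen == 3

# ... run token-level compute (attention, MLP) on the packed tensor here ...

# Restore the padded layout; padded positions are filled with zeros.
repadded = pad_input(unpadded, indices, batch, seqlen)
assert repadded.shape == (batch, seqlen, hidden)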