-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Expand file tree
/
Copy pathutils.py
More file actions
210 lines (166 loc) · 6.36 KB
/
utils.py
File metadata and controls
210 lines (166 loc) · 6.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
from typing import Iterable, List
logger = logging.getLogger(__name__)
import einops
import torch
import torch.nn as nn
# Public API of this module. The previous list omitted most of the public
# definitions below; all names are kept and new ones only added, so existing
# `from ... import *` users are unaffected.
__all__ = [
    'if_exist',
    '_compute_softmax',
    'flatten',
    'flatten_iterable',
    'extend_instance',
    'apply_rope_scaling',
    'mask_sequence_tensor',
    'activation_registry',
    'ClampActivation',
    'snake',
    'Snake',
    'HalfSnake',
]

# Maps activation names (as used in configs) to their torch.nn module classes.
# Note: "swish" and "silu" are intentionally the same activation.
activation_registry = {
    "identity": nn.Identity,
    "hardtanh": nn.Hardtanh,
    "relu": nn.ReLU,
    "selu": nn.SELU,
    "swish": nn.SiLU,
    "silu": nn.SiLU,
    "gelu": nn.GELU,
}
def if_exist(outfold: str, files: List[str]) -> bool:
    """
    Return True if all given files exist in the given folder.

    Args:
        outfold: folder path
        files: list of file names relative to outfold

    Returns:
        True when outfold exists and every listed file exists inside it
        (vacuously True for an empty file list); False otherwise.
    """
    if not os.path.exists(outfold):
        return False
    # os.path.join is portable across platforms, unlike manual '/' concatenation.
    return all(os.path.exists(os.path.join(outfold, name)) for name in files)
def _compute_softmax(scores):
"""Compute softmax probability over raw logits."""
if not scores:
return []
max_score = None
for score in scores:
if max_score is None or score > max_score:
max_score = score
exp_scores = []
total_sum = 0.0
for score in scores:
x = math.exp(score - max_score)
exp_scores.append(x)
total_sum += x
probs = []
for score in exp_scores:
probs.append(score / total_sum)
return probs
def flatten_iterable(iter: Iterable) -> Iterable:
    """Lazily flatten an arbitrarily nested iterable.

    Strings are treated as atomic leaves (not iterated character by
    character); every other iterable is recursed into.

    Args:
        iter: iterable containing values at the deepest level.

    Yields:
        The leaf values in depth-first order.
    """
    for element in iter:
        is_leaf = isinstance(element, str) or not isinstance(element, Iterable)
        if is_leaf:
            yield element
        else:
            yield from flatten_iterable(element)
def flatten(list_in: List) -> List:
    """Flatten a (possibly nested) list of values into a single flat list.

    Args:
        list_in: list of values, possibly nested

    Returns:
        A flat list of the leaf values in depth-first order.
    """
    # Materialize the lazy generator into a concrete list.
    return [*flatten_iterable(list_in)]
def extend_instance(obj, mixin):
    """Apply mixins to a class instance after creation.

    Swaps obj's class for a dynamically created subclass whose MRO places
    mixin before the original class, so the mixin's method overrides
    (e.g. forward()) take precedence while isinstance checks against the
    original class keep working.
    """
    original_cls = type(obj)
    # mixin needs to go first for our forward() logic to work
    obj.__class__ = type(original_cls.__name__, (mixin, original_cls), {})
def apply_rope_scaling(freqs, scale_factor=8, low_freq_factor=1, high_freq_factor=4, old_context_len=8192):
    """Rescale RoPE inverse frequencies for long-context extrapolation.

    Each frequency is handled by wavelength band: high-frequency entries are
    kept as-is, low-frequency entries are divided by scale_factor, and the
    band in between is linearly interpolated between the two.

    Args:
        freqs: 1-D tensor of RoPE inverse frequencies.
        scale_factor: divisor applied to low-frequency components.
        low_freq_factor / high_freq_factor: band boundaries relative to
            old_context_len (must differ).
        old_context_len: context length the original frequencies were tuned for.

    Returns:
        A new tensor of scaled frequencies with freqs' dtype and device.
    """
    logger.info("apply rope scaling ...")
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    scaled = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency band: leave untouched.
            scaled.append(freq)
            continue
        if wavelen > low_freq_wavelen:
            # Low-frequency band: shrink uniformly.
            scaled.append(freq / scale_factor)
            continue
        # Transition band: blend between the scaled and unscaled frequency.
        assert low_freq_wavelen != high_freq_wavelen
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        scaled.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(scaled, dtype=freqs.dtype, device=freqs.device)
def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor):
    """
    For tensors containing sequences, zero out out-of-bound elements given lengths of every element in the batch.

    Args:
        tensor: tensor of shape (B, L), (B, D, L) or (B, D1, D2, L), where the
            last dimension is the sequence/time axis
        lengths: LongTensor of shape (B,) with the valid length of each batch element

    Returns:
        A tensor of the same shape with positions >= length zeroed out.

    Raises:
        ValueError: if tensor is not 2-, 3- or 4-dimensional.
    """
    if not 2 <= tensor.dim() <= 4:
        raise ValueError('Can only mask tensors of shape B x L, B x D x L and B x D1 x D2 x L')
    batch_size = tensor.shape[0]
    max_length = tensor.shape[-1]
    # Position index t is valid iff t < length: compare (1, L) against (B, 1)
    # to get a (B, L) boolean mask in a single broadcasted op (no einops needed).
    positions = torch.arange(max_length, device=lengths.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(-1)
    # Insert singleton dims so the mask broadcasts over any middle dimensions.
    view_shape = [batch_size] + [1] * (tensor.dim() - 2) + [max_length]
    return tensor * mask.view(view_shape)
class ClampActivation(nn.Module):
def __init__(self, min_value: float = -1.0, max_value: float = 1.0):
super().__init__()
self.min_value = min_value
self.max_value = max_value
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.clamp(input, min=self.min_value, max=self.max_value)
@torch.jit.script
def snake(x: torch.Tensor, alpha: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """
    Snake activation: x + (alpha + eps)^-1 * sin(alpha * x)^2

    eps keeps the reciprocal finite when alpha is zero. The input is
    flattened to (B, C, T) so alpha (one value per channel) broadcasts,
    then restored to its original shape.
    """
    original_shape = x.shape
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    activated = flat + (alpha + eps).reciprocal() * torch.sin(alpha * flat).pow(2)
    return activated.reshape(original_shape)
class Snake(nn.Module):
    """
    Snake activation function introduced in 'https://arxiv.org/abs/2006.08195'.

    Wraps the functional `snake` with one learnable alpha per channel.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Shape (1, C, 1): one learnable frequency per channel, broadcast
        # over batch and time dimensions.
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the snake activation elementwise to x of shape (B, C, ...)."""
        return snake(x, self.alpha)
class HalfSnake(nn.Module):
    """
    Activation which applies snake to the first half of input channels and
    leaky relu to the second half.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Integer split point; with odd channel counts the leaky-relu half
        # receives the extra channel.
        self.snake_channels = channels // 2
        self.snake_act = Snake(self.snake_channels)
        self.lrelu = torch.nn.LeakyReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Split x along dim 1, activate each half, and re-concatenate."""
        first_half = x[:, : self.snake_channels, :]
        second_half = x[:, self.snake_channels :, :]
        return torch.cat([self.snake_act(first_half), self.lrelu(second_half)], dim=1)