"""Utility functions used throughout Megatron core"""
fromfunctoolsimportreduce
importoperator
importtorch


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator
    )


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
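
# Illustrative sketch (the names below are hypothetical, not part of this
# module): divide() is typically used to split a dimension evenly across a
# parallel group, failing fast when the split is uneven.
#
#   per_partition_size = divide(hidden_size, tensor_model_parallel_size)
#   # e.g., divide(4096, 8) == 512, while divide(4096, 3) raises AssertionError.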


class GlobalMemoryBuffer:
    """Global buffer to avoid dynamic memory allocations.
    Caller should ensure that buffers of the same name
    are not used concurrently."""

    def __init__(self):
        self.buffer = {}

    def get_tensor(self, tensor_shape, dtype, name):
        """Return a cached tensor keyed by (name, dtype), viewed as
        tensor_shape; grow the underlying allocation only when the cached
        buffer is too small."""
        required_len = reduce(operator.mul, tensor_shape, 1)
        if self.buffer.get((name, dtype), None) is None or \
                self.buffer[(name, dtype)].numel() < required_len:
            self.buffer[(name, dtype)] = \
                torch.empty(required_len,
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False)

        return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
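
# Illustrative usage sketch (assumes a CUDA device is available, since
# get_tensor() allocates on torch.cuda.current_device()):
#
#   buffer = GlobalMemoryBuffer()
#   a = buffer.get_tensor((4, 1024), torch.float16, "scratch")
#   # A later request under the same (name, dtype) key needing no more
#   # elements reuses the cached allocation instead of allocating again.
#   b = buffer.get_tensor((2, 1024), torch.float16, "scratch")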


def _kernel_make_viewless_tensor(inp, requires_grad):
    '''Make a viewless tensor.

    View tensors have the undesirable side-effect of retaining a reference
    to the originally-viewed tensor, even after manually setting the '.data'
    field. This method creates a new tensor that links to the old tensor's
    data, without linking the viewed tensor, referenced via the '._base'
    field.
    '''
    out = torch.empty(
        (1,),
        dtype=inp.dtype,
        device=inp.device,
        requires_grad=requires_grad,
    )
    out.data = inp.data
    return out
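
# Background sketch: in PyTorch, a view pins its source tensor via '._base'.
#
#   x = torch.ones(4, 4)
#   v = x.view(16)
#   assert v._base is x           # the view keeps the original alive
#   out = _kernel_make_viewless_tensor(v, requires_grad=False)
#   assert out._base is None      # same storage as v, but no '._base' link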


class MakeViewlessTensor(torch.autograd.Function):
    '''
    Autograd function to make a viewless tensor.

    This function should be used in cases where the computation graph needs
    to be propagated, but we only want a viewless tensor (e.g.,
    ParallelTransformer's hidden_states). Call this function by passing
    'keep_graph = True' to 'make_viewless_tensor()'.
    '''

    @staticmethod
    def forward(ctx, inp, requires_grad):
        return _kernel_make_viewless_tensor(inp, requires_grad)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


def make_viewless_tensor(inp, requires_grad, keep_graph):
    '''
    Entry-point for creating viewless tensors.

    This method should be used, rather than calling 'MakeViewlessTensor'
    or '_kernel_make_viewless_tensor' directly. This method acts as a
    switch for determining if an autograd function or a regular method
    should be used to create the tensor.
    '''

    # return tensor as-is, if not a 'view'
    if inp._base is None:
        return inp

    # create viewless tensor
    if keep_graph:
        return MakeViewlessTensor.apply(inp, requires_grad)
    else:
        return _kernel_make_viewless_tensor(inp, requires_grad)
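
# Illustrative usage sketch (tensor names are hypothetical):
#
#   # 'hidden_states' is a view here, so it pins 'layer_output' via '._base'.
#   hidden_states = layer_output.view(seq_len, batch, hidden)
#   # Keep autograd flowing while dropping the '._base' reference.
#   hidden_states = make_viewless_tensor(
#       hidden_states, requires_grad=True, keep_graph=True
#   )
#   # With keep_graph=False, the plain (non-autograd) kernel path is used.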


def assert_viewless_tensor(tensor, extra_msg=None):
    '''Assert that a tensor is not a view (i.e., its '._base' field is
    not set).'''
    if isinstance(tensor, list):
        for t in tensor:
            assert_viewless_tensor(t)
        return tensor
    if not isinstance(tensor, torch.Tensor):
        return tensor
    assert tensor._base is None, (
        "Ensure tensor._base is None before setting tensor.data or storing "
        "tensor to memory buffer. Otherwise, a memory leak will occur (and "
        "likely accumulate over iterations). %s"
    ) % extra_msg
    return tensor


def safely_set_viewless_tensor_data(tensor, new_data_tensor):
    '''Safely set tensor's '.data' field.

    Check first that the tensor is viewless (i.e., '._base' not set). If not,
    raise an exception.
    '''
    assert_viewless_tensor(
        tensor,
        extra_msg="FYI, tensor._base has shape %s, and new_data_tensor has "
        "shape %s." % (
            "--" if tensor._base is None else tensor._base.shape,
            new_data_tensor.shape,
        ),
    )
    tensor.data = new_data_tensor
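
# Illustrative usage sketch: swapping a tensor's storage in place.
#
#   t = torch.zeros(16)
#   safely_set_viewless_tensor_data(t, torch.ones(16))    # ok: t._base is None
#   v = t.view(4, 4)
#   safely_set_viewless_tensor_data(v, torch.ones(4, 4))  # raises AssertionError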