This repository was archived by the owner on Feb 25, 2022. It is now read-only.
- Notifications
You must be signed in to change notification settings - Fork 962
/
Copy path sample.py
218 lines (190 loc) · 9.11 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
importmesh_tensorflowasmtf
importtensorflow.compat.v1astf
importmesh_tensorflow.transformerasmtf_transformer
frommodels.utilsimportentmax, sample_categorical
frommodels.gpt2importgpt2
def sample_autoregressive(partial_sequences,
                          other_features,
                          params,
                          stop_at_token=50256,
                          max_steps=None,
                          temperature=0.9,
                          variable_dtype=mtf.VariableDType(tf.float32),
                          encoder_output=None,
                          encoder_sequence_id=None,
                          encoder_inputs=None,
                          shared_params=None,
                          has_partial_sequences=True,
                          encoder_layer_outputs=None,
                          never_end=False,
                          remove_partial_sequences=False,
                          sampling_keep_top_k=-1,
                          sampling_use_entmax=False,
                          bos_id=50256,
                          ):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued. The
    first tokens of each sequence are non-padding tokens representing the given
    partial sequences, and the last tokens of each sequence are padding
    (params["padding_id"], default 0), representing what needs to be filled in.

    If there are no partial sequences (you want to sample from the beginning),
    then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
    has_partial_sequences=False (so we can skip computation).

    Args:
        partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
        other_features: a dict passed through to gpt2.model; must contain
            "vocab_dim", the vocabulary Dimension of the logits.
        params: a dict of model params. This function reads "padding_id"
            (default 0) and "slow_sampling" (default False); the whole dict is
            also passed through to gpt2.model.
        stop_at_token: an optional integer eos id. Stop when we produce it.
        max_steps: an optional integer, the max number of steps to decode.
            A falsy value (None or 0) means no step limit.
        temperature: an optional floating point value between 0.0 and 1.0;
            0.0 means argmax, 1.0 means sample according to predicted
            distribution. Ignored when sampling_use_entmax is set.
        variable_dtype: a mtf.VariableDType
        encoder_output: an optional Tensor
        encoder_sequence_id: an optional Tensor
        encoder_inputs: an optional Tensor
        shared_params: an optional dictionary
        has_partial_sequences: a boolean. If False, the states produced by the
            first (prompt) pass are replaced with zeros.
        encoder_layer_outputs: optional - readonly list of tensor activations
            when decoding, one per each input layer + the embedding layer
        never_end: a boolean - if set, then avoid generating stop_at_token
            (its logit is pushed down by 1e9 before sampling). Has no effect
            when stop_at_token is None.
        remove_partial_sequences: a boolean - whether to remove the partial
            sequences from the output
        sampling_keep_top_k: an integer - if not -1, only sample from the top k
            logits. The special value -2 keeps the top 10% of the vocabulary.
            Ignored when sampling_use_entmax is set.
        sampling_use_entmax: a boolean - if set, sample from entmax(logits)
            instead of temperature / top-k sampling.
        bos_id: beginning of sequence id. Currently unused by this function.

    Returns:
        a Tensor with shape [<batch_dims>, length_dim]

    Raises:
        ValueError: if sampling_keep_top_k (after resolving -2) is neither -1
            nor positive. Raised while the graph is being built.
    """
    inputs = partial_sequences  # Partial sequences to fill in
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    padding_id = params.get("padding_id", 0)
    slow_sampling = params.get("slow_sampling", False)

    # Number of non-padding tokens per sequence, i.e. the position where the
    # padding starts and where decoding begins.
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, padding_id)), reduced_dim=length_dim)
    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)

    input_full_attention = True  # for now hardcode this to true bc lazy
    if input_full_attention:
        # Positions <= initial_position get priority 0; later positions get
        # their own index as priority. NOTE(review): presumably this lets the
        # whole prompt attend to itself while generated tokens stay causal --
        # confirm against how gpt2.model consumes read/write priorities.
        read_priority = write_priority = length_range * mtf.to_int32(
            mtf.greater(length_range, initial_position))
    else:
        read_priority = write_priority = length_range

    # Builds context to pass around internally.
    if not slow_sampling:
        # The "first_part" pass runs the model once over the whole prompt.
        # gpt2.model appends its per-layer states (presumably attention k / v)
        # to context_first_part.new_states, which seed the incremental loop.
        context_first_part = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        with tf.variable_scope("gpt2"):
            # Run only to populate context_first_part.new_states; the logits
            # of the prompt pass are not needed.
            gpt2.model({"inputs": inputs}, other_features, params, inputs.mesh,
                       variable_dtype=variable_dtype, context=context_first_part)

        if not has_partial_sequences:
            initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states]
        else:
            initial_states = context_first_part.new_states
    else:
        # Slow sampling re-runs the full model each step, so no cached states.
        initial_states = []

    # Count EOS tokens already present in the prompt so that only newly
    # generated EOS tokens terminate decoding.
    partial_sequences_eos_count = 0
    if stop_at_token is not None:
        partial_sequences_eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
            reduced_dim=length_dim)

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end, mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        if stop_at_token is not None:
            eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(ids, stop_at_token)),
                reduced_dim=length_dim)
            has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)
        # Keep looping until every sequence in the batch is done.
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        nonlocal sampling_keep_top_k

        if slow_sampling:
            context = None
        else:
            context = mtf_transformer.transformer.Context(
                model=None,
                mesh=inputs.mesh,
                batch_dims=batch_dims,
                length_dim=length_dim,
                variable_dtype=variable_dtype,
                mode="incremental",
                position=position,
                position_is_default=True,
                states=states,
                new_states=[],
                initial_position=position,
                sequence_id=None,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs,
                write_priority=write_priority,
                read_priority=read_priority,
                inputs=ids,
                encoder_inputs=encoder_inputs)

        with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
            logits, _, _ = gpt2.model({"inputs": ids}, other_features, params, inputs.mesh,
                                      variable_dtype=variable_dtype, context=context)

        if never_end and stop_at_token is not None:
            # Push the EOS logit down by 1e9 so it is effectively never sampled.
            logits += mtf.one_hot(
                mtf.constant(logits.mesh, stop_at_token, dtype=tf.int32),
                other_features["vocab_dim"], on_value=-1e9, off_value=0.0,
                dtype=logits.dtype)

        if not sampling_use_entmax:
            # -2 is shorthand for "keep the top 10% of the vocabulary"
            # (assumes the vocab dimension is the last logits dimension).
            if sampling_keep_top_k == -2:
                sampling_keep_top_k = int(logits.shape[-1].size * 0.1)

            if sampling_keep_top_k != -1:
                if sampling_keep_top_k <= 0:
                    raise ValueError("sampling_keep_top_k must either be -1 or positive.")
                # Mask every logit <= the threshold to -1e6.
                # NOTE(review): assumes nth_largest_element's n is zero-indexed,
                # so exactly sampling_keep_top_k logits survive -- confirm.
                k_largest = mtf.nth_largest_element(
                    logits, n=sampling_keep_top_k,
                    reduced_dim=other_features["vocab_dim"])
                logits = mtf.where(mtf.less_equal(logits, k_largest),
                                   mtf.ones_like(logits) * -1e6, logits)

            ids_this_step = mtf.sample_with_temperature(
                logits, other_features["vocab_dim"], temperature)
        else:
            ids_this_step = sample_categorical(entmax(logits))

        if slow_sampling:
            # The full model yields one token per position; shifting by one
            # presumably aligns the prediction made at position-1 with
            # `position`.
            ids_this_step = mtf.shift(ids_this_step, offset=1, dim=length_dim, wrap=False)
        else:
            ids_this_step = mtf.reshape(ids_this_step, batch_dims)

        # Write the sampled token into slot `position`, keeping all other ids.
        one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
        one_new_id = ids_this_step * one_hot
        new_ids = (1 - one_hot) * ids + one_new_id
        new_position = position + 1

        ret = [new_position, new_ids]
        if context is not None:
            ret += context.new_states
        return ret

    while_loop_inputs = [initial_position, inputs] + initial_states
    _, outputs = mtf.while_loop(cond_fn, body_fn, while_loop_inputs)[:2]

    if has_partial_sequences and remove_partial_sequences:
        # Remove partial sequences from outputs. initial_position already holds
        # the count of non-padding prompt tokens per sequence.
        outputs = mtf.dynamic_shift(
            outputs, -initial_position, length_dim, wrap=False)
    return outputs