- Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathautoencoder.py
71 lines (51 loc) · 2.36 KB
/
autoencoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class EncoderRNN(nn.Module):
    """Encoder half of the autoencoder.

    A linear layer collapses the trailing embedding dimension of the input
    to a single scalar; the squeezed sequence is then run through a
    single-layer GRU.

    Args:
        input_size: kept for interface compatibility; not used by the
            current layers.
        embedding_size: size of the trailing dimension that ``self.linear``
            projects down to 1.
        hidden_size: GRU hidden-state size. After the squeeze, the tensor
            fed to the GRU must have ``hidden_size`` features.
    """

    def __init__(self, input_size, embedding_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        # Projects the last (embedding) dimension down to a single scalar.
        self.linear = nn.Linear(embedding_size, 1)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Encode one batch.

        NOTE(review): for the `view` below to produce a valid GRU input,
        `input` presumably has shape
        (seq_len, batch, hidden_size, embedding_size) — TODO confirm
        against the caller.
        """
        input_squeezed = self.linear(input)
        # Drop the trailing singleton dimension produced by the linear layer.
        output, hidden = self.gru(input_squeezed.view(input_squeezed.size()[:-1]), hidden)
        return output, hidden

    def initHidden(self, batch_size):
        # Initial hidden state: (num_layers=1, batch, hidden_size).
        return torch.zeros(1, batch_size, self.hidden_size, device=device)
class DecoderRNN(nn.Module):
    """Decoder half of the autoencoder.

    Unrolls a GRUCell for a fixed number of steps, projecting each hidden
    state to ``output_size`` features and feeding that projection back in
    as the next step's input (no teacher forcing).

    Args:
        hidden_size: GRUCell hidden-state size.
        embedding_size: kept for interface compatibility; unused.
        output_size: feature width of each generated step.
    """

    def __init__(self, hidden_size, embedding_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        # BUGFIX: the cell consumes the re-fed projection, which has
        # output_size features (see forward), so its input_size must be
        # output_size. The original GRUCell(hidden_size, hidden_size)
        # raised a size-mismatch error whenever output_size != hidden_size;
        # this is identical when they are equal.
        self.gru = nn.GRUCell(output_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, steps):
        """Unroll for ``steps`` steps.

        Args:
            input: first-step input, shape (batch, output_size).
            hidden: initial cell state, shape (batch, hidden_size).
            steps: number of decoding steps.

        Returns:
            Tensor of shape (steps, batch, output_size).
        """
        output = []
        for i in range(steps):
            hidden = self.gru(input, hidden)
            input = self.out(hidden)  # feed the projection back as next input
            output.append(input)
        return torch.stack(output, 0)

    def initHidden(self, batch_size):
        # Initial cell state: (batch, hidden_size).
        return torch.zeros(batch_size, self.hidden_size, device=device)
def train_autoencoder(input_tensor, target_tensor, encoder, decoder,
                      encoder_optimizer, decoder_optimizer, criterion,
                      max_length):
    """Run one optimization step of the seq2seq autoencoder.

    Args:
        input_tensor: encoder input; shape (in_seq_len, batch, ...) — the
            batch dimension is read from ``input_tensor.shape[1]``.
        target_tensor: reconstruction target, shape
            (out_seq_len, batch, features).
        encoder: module exposing ``initHidden(batch)`` and
            ``forward(input, hidden) -> (output, hidden)``.
        decoder: module exposing ``forward(input, hidden, steps)``.
        encoder_optimizer / decoder_optimizer: optimizers for each half.
        criterion: loss comparing decoder output to ``target_tensor``.
        max_length: kept for interface compatibility; unused (the original
            allocated an ``encoder_outputs`` buffer from it that was
            assigned twice and never read — removed here as dead code).

    Returns:
        ``loss.item()`` divided by the target sequence length (a float).
    """
    encoder_hidden = encoder.initHidden(input_tensor.shape[1])

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    target_length = target_tensor.size(0)

    # Encode the whole input sequence; only the final hidden state is used.
    encoder_output, encoder_hidden = encoder(input_tensor, encoder_hidden)

    # Decoder starts from an all-zeros input (no teacher forcing) and the
    # encoder's final hidden state ([0] strips the num_layers dim so the
    # state fits a GRUCell-style decoder).
    decoder_input = torch.zeros(target_tensor.size()[1:], device=device, dtype=torch.float)
    decoder_hidden = encoder_hidden[0]
    decoder_output = decoder(decoder_input, decoder_hidden, steps=target_tensor.shape[0])

    loss = criterion(decoder_output, target_tensor)
    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length