- Notifications
You must be signed in to change notification settings - Fork 332
/
Copy pathmultisoftmax.cpp
126 lines (112 loc) · 4.38 KB
/
multisoftmax.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#include<assert.h>
//#include <mathimf.h>
#include<multisoftmax.h>
usingnamespacestd;
// Computes log(exp(x) + exp(y))
inlinedoublelogadd(constdouble x, constdouble y) {
if (x <= -INF && y <= -INF) {
return -INF;
}
constdouble M = max(x,y);
constdouble m = min(x,y);
constdouble diff = M - m;
// return diff > 15 ? M : M + LOG(1.0 + EXP(-diff));
// return m <= -INF ? M : M + LOG(1.0f + EXP(-diff));
return diff > 15 ? M : (diff > 5 ? M + EXP(-diff) : M + LOG(1.0 + EXP(-diff)));
}
/*
* elts: (numCases, numOut)
* B: (N + 1, size + 1) -- batckward lattice matrix, MUST BE initially -INF
* fixed: (numCases, 1)
* probs: (numCases, numOut) (*out)
*
* double precision is much faster than single. :/
*/
voidMultiSoftmaxCPU_T_logspace(Matrix& elts, Matrix& logB, Matrix& probs, Matrix& fixed, int size, bool nofix) {
int numCases = elts.getNumRows();
assert(probs.isSameDims(elts));
assert(!elts.isTrans());
assert(!logB.isTrans());
assert(!probs.isTrans());
assert(fixed.getNumRows() == numCases);
assert(fixed.getNumCols() == 1);
int N = elts.getNumCols();
Matrix& logF = *newMatrix(size + 1, 1); // Forward column
// Prepare logB
logB(N, 0) = 0;
for (int c = 0; c < numCases; ++c) {
int fx = nofix ? -1 : int(fixed(c, 0));
// Backward pass
for (int i = N - 1; i >= 0; --i) {
double elt = elts(c, i);
logB(i, 0) = i <= fx ? -INF : 0.0f;
for (int s = max(1, size - i); s < size + 1; ++s) {
logB(i, s) = fx == i ? logB(i + 1, s - 1) + elt : logadd(logB(i + 1, s - 1) + elt, logB(i + 1, s));
}
}
// Log partition function
double logZ = logB(0, size);
// Forward pass
logF.apply(Matrix::ONE);
logF.scale(-INF);
logF(0, 0) = 0;
for (int i = 1; i < N + 1; ++i) {
double logy = -INF;
double elt = elts(c, i - 1);
for (int s = size; s >= 0; --s) {
if (s < size) {
logy = logadd(logy, logF(s, 0) + logB(i, size - 1 - s));
}
if (s > 0) {
logF(s, 0) = fx == i - 1 ? logF(s - 1, 0) + elt : logadd(logF(s - 1, 0) + elt, logF(s, 0));
} elseif (fx == i - 1) {
logF(0, 0) = -INF;
}
}
logy += elt - logZ;
probs(c, i - 1) = EXP(logy) - (fx >= 0 ? probs(c, i - 1) : 0);
}
}
delete &logF;
}
// Stores the arguments for a later call to MultiSoftmaxCPU_T_logspace made
// from run(). The Thread base is constructed with `true`.
//
// Ownership: elts, probs and fixed are deleted by the destructor; B is only
// borrowed and is never deleted by this class.
MultiSoftmaxWorker::MultiSoftmaxWorker(Matrix* elts, Matrix* B, Matrix* probs, Matrix* fixed, int size, bool nofix)
    : Thread(true),
      _elts(elts),
      _B(B),
      _probs(probs),
      _fixed(fixed),
      _size(size),
      _nofix(nofix)
{
}
// Releases the three matrices handed over at construction time.
// _B is intentionally left alone: it is not deleted here.
MultiSoftmaxWorker::~MultiSoftmaxWorker()
{
    delete _elts;   // input slice
    delete _probs;  // output slice
    delete _fixed;  // fixed-index slice
}
void* MultiSoftmaxWorker::run() {
MultiSoftmaxCPU_T_logspace(*_elts, *_B, *_probs, *_fixed, _size, _nofix);
returnNULL;
}
/*
* elts: (numCases, numOut)
* B: vector of (N + 1, size + 1) -- batckward lattice matrix, should be initially zero
* fixed: (numCases, 1)
* probs: (numCases, numOut) (*out)
*
* NOTE: remember to write a version of this for transposed matrices.
* It may end up being significantly faster, which is important if
* I plan to use CPU for this.
*/
voidMultiSoftmaxCPU_T_parallel(Matrix& elts, vector<Matrix*>& B, Matrix& probs, Matrix& fixed, int size, bool nofix) {
int numCases = elts.getNumRows();
int numWorkers = min(numCases, (int)B.size());
probs.resize(elts);
int casesPerWorker = DIVUP(numCases, B.size());
numWorkers = min(numWorkers, DIVUP(numCases, casesPerWorker));
vector<Thread*> workers;
for (int i = 0; i < numWorkers; ++i) {
Matrix* eltSlice = &elts.sliceRows(i * casesPerWorker, min(elts.getNumRows(), (longint)(i + 1) * casesPerWorker));
Matrix* probSlice = &probs.sliceRows(i * casesPerWorker, min(elts.getNumRows(), (longint)(i + 1) * casesPerWorker));
Matrix* fixedSlice = &fixed.sliceRows(i * casesPerWorker, min(elts.getNumRows(), (longint)(i + 1) * casesPerWorker));
workers.push_back(newMultiSoftmaxWorker(eltSlice, B[i], probSlice, fixedSlice, size, nofix));
workers[i]->start();
}
for (int i = 0; i < numWorkers; ++i) {
workers[i]->join();
delete workers[i];
}
}