- Notifications
You must be signed in to change notification settings - Fork 332
/
Copy pathcpuCNN.cu
65 lines (60 loc) · 2.79 KB
/
cpuCNN.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include"softmaxtree.cuh"
/*
* weights: (numNodes, numFeatures)
* targets: (numNodes, numFeatures)
*
*/
voidcpuSoftmaxTreeFwd(float* weights, float* targets, constint numFeatures, SoftmaxTree& tree) {
for (int d = 0; d <= tree.getDepth(); ++d) {
for (SoftmaxNodeV::iterator it = tree.getNodesAtDepth(d).begin(); it!= tree.getNodesAtDepth(d).end(); ++it) {
SoftmaxNode& node = **it;
SoftmaxNode* parent = node.getParent();
for (int f = 0; f < numFeatures; ++f) {
targets[node.getLabel() * numFeatures + f] = weights[node.getLabel() * numFeatures + f]
+ (parent == NULL ? 0 : targets[parent->getLabel() * numFeatures + f]);
}
}
}
}
/*
* grads: (numNodes, numFeatures)
*
*/
voidcpuSoftmaxTreeBwd(float* grads, constint numFeatures, SoftmaxTree& tree) {
for (int h = 1; h <= tree.getHeight(); ++h) {
for (SoftmaxNodeV::iterator it = tree.getNodesAtHeight(h).begin(); it!= tree.getNodesAtHeight(h).end(); ++it) {
SoftmaxNode& node = **it;
for (int f = 0; f < numFeatures; ++f) {
grads[node.getLabel() * numFeatures + f] = 0;
}
for (SoftmaxNodeV::iterator itc = node.getChildren().begin(); itc!= node.getChildren().end(); ++itc) {
SoftmaxNode& child = **itc;
for (int f = 0; f < numFeatures; ++f) {
grads[node.getLabel() * numFeatures + f] += grads[child.getLabel() * numFeatures + f];
}
}
}
}
}
/*
* weights: (numNodes, numFeatures)
* weightsInc: (numNodes, numFeatures)
* weightsGrad: (numNodes, numFeatures)
* nodeSizes: numNodes-array whose ith element gives number of leaves under
* node with label i.
*/
voidcpuSoftmaxTreeUpdateWeights(float* weights, float* weightsInc, float* weightsGrad,
constint numFeatures, float eps, constfloat mom, float wc, SoftmaxTree& tree) {
for (int d = 0; d <= tree.getDepth(); d++) {
for (SoftmaxNodeV::iterator it = tree.getNodesAtDepth(d).begin(); it!= tree.getNodesAtDepth(d).end(); ++it) {
SoftmaxNode& node = **it;
float w = wc / node.getSize();
float e = eps;// * sqrt(node.getSize());
for (int f = 0; f < numFeatures; ++f) {
weightsInc[node.getLabel() * numFeatures + f] = mom * weightsInc[node.getLabel() * numFeatures + f]
+ e * (weightsGrad[node.getLabel() * numFeatures + f] - w * weights[node.getLabel() * numFeatures + f]);
weights[node.getLabel() * numFeatures + f] += weightsInc[node.getLabel() * numFeatures + f];
}
}
}
}