- Notifications
You must be signed in to change notification settings - Fork 10.5k
/
Copy path: differential_operators.swift.gyb
64 lines (51 loc) · 2.17 KB
/
differential_operators.swift.gyb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
// RUN: %empty-directory(%t)
// RUN: %gyb %s -o %t/differential_operators.swift
// RUN: %target-build-swift %t/differential_operators.swift -o %t/differential_operators
// RUN: %target-codesign %t/differential_operators
// RUN: %target-run %t/differential_operators
// REQUIRES: executable_test
import _Differentiation
import StdlibUnittest
// Test suite that every generated differential-operator test below registers on.
// The binding is never reassigned, so it is declared with `let`.
let DifferentialOperatorTestSuite = TestSuite("DifferentialOperator")
// For each arity N in 1...3, generate a differentiable function
// `exampleDiffFunc_N`, a custom VJP registered for it via `@derivative(of:)`,
// and one test per differential operator: `valueWithPullback`, `pullback`,
// `gradient`, `valueWithGradient`, and the curried `gradient(of:)`.
// The generated functions' bodies only call `fatalError()`, so the tests can
// pass only if the operators use the registered VJPs.
% for arity in range(1, 3 + 1):
% params = ', '.join(['_ x%d: Float' % i for i in range(arity)])
% pb_return_type = '(' + ', '.join(['Float' for _ in range(arity)]) + ')'

// Arity ${arity}: the body traps; the tests below expect the operators to use
// `exampleVJP_${arity}` instead of evaluating this body.
func exampleDiffFunc_${arity}(${params}) -> Float {
  fatalError()
}

// Custom VJP for arity ${arity}: the value is the sum of the squared arguments,
// and the pullback scales each partial derivative (2 * xi) by the incoming
// cotangent.
@derivative(of: exampleDiffFunc_${arity})
func exampleVJP_${arity}(${params}) -> (value: Float, pullback: (Float) -> ${pb_return_type}) {
  (
    ${' + '.join(['x%d * x%d' % (i, i) for i in range(arity)])},
    { (${', '.join(['2 * x%d * $0' % i for i in range(arity)])}) }
  )
}

% argValues = [i * 10 for i in range(1, arity + 1)]
% args = ', '.join([str(v) for v in argValues])
% expectedValue = sum([v * v for v in argValues])
% expectedGradientValues = [2 * v for v in argValues]
% expectedGradients = '(' + ', '.join([str(g) for g in expectedGradientValues]) + ')'

// Arity ${arity}: inputs (${args}), expected value ${expectedValue},
// expected gradient ${expectedGradients}.
DifferentialOperatorTestSuite.test("valueWithPullback_${arity}") {
  let (value, pb) = valueWithPullback(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedValue}, value)
  expectEqual(${expectedGradients}, pb(1))
}

DifferentialOperatorTestSuite.test("pullback_${arity}") {
  let pb = pullback(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedGradients}, pb(1))
}

DifferentialOperatorTestSuite.test("gradient_${arity}") {
  let grad = gradient(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedGradients}, grad)
}

DifferentialOperatorTestSuite.test("valueWithGradient_${arity}") {
  let (value, grad) = valueWithGradient(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedValue}, value)
  expectEqual(${expectedGradients}, grad)
}

DifferentialOperatorTestSuite.test("gradient_curried_${arity}") {
  let gradF = gradient(of: exampleDiffFunc_${arity})
  expectEqual(${expectedGradients}, gradF(${args}))
}
% end
// Hand control to the StdlibUnittest runner so the tests registered on
// DifferentialOperatorTestSuite above are executed (the RUN lines execute this
// binary via %target-run). Presumably the process exits nonzero on any test
// failure — confirm against StdlibUnittest.
runAllTests()