- Notifications
You must be signed in to change notification settings - Fork 228
/
Copy pathtest_quadruplets_classifiers.py
65 lines (58 loc) · 2.92 KB
/
test_quadruplets_classifiers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
importpytest
fromsklearn.exceptionsimportNotFittedError
fromsklearn.model_selectionimporttrain_test_split
fromtest.test_utilsimportquadruplets_learners, ids_quadruplets_learners
frommetric_learn.sklearn_shimsimportset_random_state
fromsklearnimportclone
importnumpyasnp
@pytest.mark.parametrize('with_preprocessor', [True, False])
@pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners,
                         ids=ids_quadruplets_learners)
def test_predict_only_one_or_minus_one(estimator, build_dataset,
                                       with_preprocessor):
    """Test that all predicted values are either +1 or -1.

    A clone of the parametrized estimator is given the dataset's
    preprocessor, fitted on a train split of the quadruplets, and asked to
    predict on the held-out split. Labels returned by ``build_dataset`` are
    not needed: quadruplet learners are fitted on the tuples alone here.
    """
    input_data, _, preprocessor, _ = build_dataset(with_preprocessor)
    estimator = clone(estimator)
    estimator.set_params(preprocessor=preprocessor)
    set_random_state(estimator)
    # Seed the split too: set_random_state only fixes the estimator's own
    # randomness, and an unseeded split makes failures hard to reproduce.
    quadruplets_train, quadruplets_test = train_test_split(input_data,
                                                           random_state=0)
    estimator.fit(quadruplets_train)
    predictions = np.asarray(estimator.predict(quadruplets_test))
    # Collect the offending values so a failure shows what was returned,
    # not just that something was wrong.
    invalid = predictions[~np.isin(predictions, [-1, 1])]
    assert invalid.size == 0, (
        "predict returned values other than -1/+1: {}".format(
            np.unique(invalid)))
@pytest.mark.parametrize('with_preprocessor', [True, False])
@pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners,
                         ids=ids_quadruplets_learners)
def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset,
                                              with_preprocessor):
    """Predicting with a metric learner that was never fitted must raise
    ``NotFittedError``.

    The learner is cloned, given the dataset's preprocessor and a fixed
    random state, but ``fit`` is deliberately never called before
    ``predict``.
    """
    quadruplet_input, _, preprocessor, _ = build_dataset(with_preprocessor)
    unfitted_learner = clone(estimator)
    unfitted_learner.set_params(preprocessor=preprocessor)
    set_random_state(unfitted_learner)
    with pytest.raises(NotFittedError):
        unfitted_learner.predict(quadruplet_input)
@pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners,
                         ids=ids_quadruplets_learners)
def test_accuracy_toy_example(estimator, build_dataset):
    """Check the default quadruplet scoring (accuracy) on a hand-built case.

    Four evenly spaced points are placed on the line through the first two
    samples of ``X``. The learner's transformation is then forced to the
    identity, so (assuming distances are Euclidean after the transform)
    the distance between two of those points is proportional to the gap
    between their indices. This makes the expected score computable by hand.
    """
    train_input, _, preprocessor, X = build_dataset(with_preprocessor=False)
    learner = clone(estimator)
    learner.set_params(preprocessor=preprocessor)
    set_random_state(learner)
    learner.fit(train_input)

    # Point k is X[0] + k * (X[0] - X[1]) / 4 for k = 0..3: equal steps
    # starting at X[0] and moving away from X[1].
    # NOTE(review): assumes X[0] != X[1], otherwise all points coincide.
    offsets = np.arange(4)[:, np.newaxis]
    line_points = X[0] + offsets * (X[0] - X[1]) / 4

    # Each row (a, b, c, d) indexes line_points and compares pair (a, b)
    # with pair (c, d). Index gaps are 2 vs 1, 2 vs 1, 1 vs 3 and 3 vs 1,
    # so only the third row has a first pair closer than its second pair.
    quadruplet_indices = np.array([[0, 2, 0, 1],
                                   [1, 3, 1, 0],
                                   [1, 2, 0, 3],
                                   [3, 0, 2, 1]])
    quadruplets_test = line_points[quadruplet_indices]

    # Overwrite the fitted transformation with the identity so the distance
    # used for scoring is fully under the test's control.
    learner.components_ = np.eye(X.shape[1])
    # NOTE(review): 0.25 presumes score() counts a quadruplet as correct only
    # when its first pair is the closer one, true for exactly 1 of 4 rows.
    assert learner.score(quadruplets_test) == 0.25