train_searcher.py
import os, sys
import numpy as np
import scann
import argparse
import glob
from multiprocessing import cpu_count

from tqdm import tqdm

from ldm.util import parallel_data_prefetch
def search_bruteforce(searcher):
    # Exact (exhaustive) scoring; suitable for small pools.
    return searcher.score_brute_force().build()


def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
                        partioning_trainsize, num_leaves, num_leaves_to_search):
    # Tree partitioning + asymmetric hashing, with exact re-scoring of the top reorder_k candidates.
    return searcher.tree(num_leaves=num_leaves,
                         num_leaves_to_search=num_leaves_to_search,
                         training_sample_size=partioning_trainsize). \
        score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()


def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
    # Asymmetric hashing over the full pool, with exact re-scoring of the top reorder_k candidates.
    return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
        reorder_k).build()
def load_datapool(dpath):

    def load_single_file(saved_embeddings):
        compressed = np.load(saved_embeddings)
        database = {key: compressed[key] for key in compressed.files}
        return database

    def load_multi_files(data_archive):
        database = {key: [] for key in data_archive[0].files}
        for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
            for key in d.files:
                database[key].append(d[key])
        return database

    print(f'Loading saved patch embeddings from "{dpath}"')
    file_content = glob.glob(os.path.join(dpath, '*.npz'))

    if len(file_content) == 1:
        data_pool = load_single_file(file_content[0])
    elif len(file_content) > 1:
        data = [np.load(f) for f in file_content]
        prefetched_data = parallel_data_prefetch(load_multi_files, data,
                                                 n_proc=min(len(data), cpu_count()), target_data_type='dict')
        data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
    else:
        raise ValueError(f'No npz-files found in "{dpath}". Does this directory exist?')

    print(f'Finished loading retrieval database of length {data_pool["embedding"].shape[0]}.')
    return data_pool
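
# Illustrative sketch (not part of the original script): writes a minimal datapool that
# load_datapool above can read. Only the 'embedding' key (shape [N, D]) is used by
# train_searcher below; any additional keys are loaded and passed through unchanged.
# The file name 'part_0.npz' and the extra 'img_id' key are assumptions for illustration.
def write_example_datapool(dpath, embeddings, img_ids=None):
    os.makedirs(dpath, exist_ok=True)
    extra = {'img_id': img_ids} if img_ids is not None else {}
    np.savez(os.path.join(dpath, 'part_0.npz'), embedding=embeddings, **extra)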
def train_searcher(opt,
                   metric='dot_product',
                   partioning_trainsize=None,
                   reorder_k=None,
                   # todo tune
                   aiq_thld=0.2,
                   dims_per_block=2,
                   num_leaves=None,
                   num_leaves_to_search=None,):

    data_pool = load_datapool(opt.database)
    k = opt.knn

    if not reorder_k:
        reorder_k = 2 * k

    # l2-normalize the embeddings so that dot-product search corresponds to cosine similarity
    searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
    pool_size = data_pool['embedding'].shape[0]

    print(*(['#'] * 100))
    print('Initializing ScaNN searcher with the following values:')
    print(f'k: {k}')
    print(f'metric: {metric}')
    print(f'reorder_k: {reorder_k}')
    print(f'anisotropic_quantization_threshold: {aiq_thld}')
    print(f'dims_per_block: {dims_per_block}')
    print(*(['#'] * 100))
    print('Start training searcher....')
    print(f'N samples in pool is {pool_size}')

    # this reflects the recommended design choices proposed at
    # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
    if pool_size < 2e4:
        print('Using brute force search.')
        searcher = search_bruteforce(searcher)
    elif 2e4 <= pool_size < 1e5:
        print('Using asymmetric hashing search and reordering.')
        searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
    else:
        print('Using partitioning, asymmetric hashing search and reordering.')

        if not partioning_trainsize:
            partioning_trainsize = data_pool['embedding'].shape[0] // 10
        if not num_leaves:
            num_leaves = int(np.sqrt(pool_size))
        if not num_leaves_to_search:
            num_leaves_to_search = max(num_leaves // 20, 1)

        print('Partitioning params:')
        print(f'num_leaves: {num_leaves}')
        print(f'num_leaves_to_search: {num_leaves_to_search}')
        searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
                                       partioning_trainsize, num_leaves, num_leaves_to_search)

    print('Finished training searcher.')
    searcher_savedir = opt.target_path
    os.makedirs(searcher_savedir, exist_ok=True)
    searcher.serialize(searcher_savedir)
    print(f'Saved trained searcher under "{searcher_savedir}"')
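
# Illustrative sketch (not part of the original script): once serialized, the searcher can be
# restored with ScaNN's pybind API and queried with l2-normalized query embeddings. The helper
# below is an assumption for illustration; the repo's own retrieval code may load the searcher
# differently.
def load_and_query_searcher(searcher_savedir, query_embeddings, k=20):
    searcher = scann.scann_ops_pybind.load_searcher(searcher_savedir)
    # normalize queries the same way the database embeddings were normalized above
    queries = query_embeddings / np.linalg.norm(query_embeddings, axis=1)[:, np.newaxis]
    neighbor_ids, distances = searcher.search_batched(queries, final_num_neighbors=k)
    return neighbor_ids, distances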
if __name__ == '__main__':
    sys.path.append(os.getcwd())

    parser = argparse.ArgumentParser()
    parser.add_argument('--database',
                        '-d',
                        default='data/rdm/retrieval_databases/openimages',
                        type=str,
                        help='path to the folder containing the CLIP features of the database')
    parser.add_argument('--target_path',
                        '-t',
                        default='data/rdm/searchers/openimages',
                        type=str,
                        help='path to the target folder where the searcher shall be stored')
    parser.add_argument('--knn',
                        '-k',
                        default=20,
                        type=int,
                        help='number of nearest neighbors for which the searcher shall be optimized')
    opt, _ = parser.parse_known_args()

    train_searcher(opt,)
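
# Example invocation (uses the argparse defaults above; the script path and database location
# depend on your checkout and may need adjusting):
#   python train_searcher.py --database data/rdm/retrieval_databases/openimages \
#                            --target_path data/rdm/searchers/openimages \
#                            --knn 20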