import argparse
from augur.io import open_file, read_metadata
import csv
import os
from pathlib import Path
import pandas as pd
import re
import shutil
import sys
from tempfile import NamedTemporaryFile

from utils import extract_tar_file_contents

# Define all possible geographic scales we could expect in the GISAID location
# field.
LOCATION_FIELDS = (
    "region",
    "country",
    "division",
    "location",
)
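# For example, a GISAID location string like
# "North America / USA / Washington / King County" fills these scales in order;
# see parse_location_string below for how shorter or inconsistently delimited
# strings are handled.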


class MissingColumnException(Exception):
    """An exception caused by a missing column that was expected in the metadata.
    """
    pass


class DuplicateException(Exception):
    """An exception caused by the presence of any duplicate metadata records by
    strain name.
    """
    pass


def parse_new_column_names(renaming_rules):
    """Parse the mapping of current to new column names from the given list of renaming rules.

    Parameters
    ----------
    renaming_rules : list[str]
        A list of strings mapping an old column name to a new one delimited by an equal symbol (e.g., "old_column=new_column").

    Returns
    -------
    dict :
        A mapping of new column names for each old column name.

    >>> parse_new_column_names(["old=new", "new=old"])
    {'old': 'new', 'new': 'old'}
    >>> parse_new_column_names(["old->new"])
    {}

    """
    new_column_names = {}
    for rule in renaming_rules:
        if "=" in rule:
            old_column, new_column = rule.split("=")
            new_column_names[old_column] = new_column
        else:
            print(
                f"WARNING: missing mapping of old to new column in form of 'Virus name=strain' for rule: '{rule}'.",
                file=sys.stderr
            )

    return new_column_names


def parse_location_string(location_string, location_fields):
    """Parse a location string from GISAID into the given separate geographic scales
    and return a dictionary of parsed values by scale.

    Parameters
    ----------
    location_string : str
    location_fields : list

    Returns
    -------
    dict :
        dictionary of geographic fields parsed from the given string

    >>> location_fields = ["region", "country", "division", "location"]
    >>> parse_location_string("Asia / Japan", location_fields)
    {'region': 'Asia', 'country': 'Japan', 'division': '?', 'location': '?'}
    >>> parse_location_string("Europe / Iceland / Reykjavik", location_fields)
    {'region': 'Europe', 'country': 'Iceland', 'division': 'Reykjavik', 'location': '?'}
    >>> parse_location_string("North America / USA / Washington / King County", location_fields)
    {'region': 'North America', 'country': 'USA', 'division': 'Washington', 'location': 'King County'}

    Additional location entries beyond what has been specified should be stripped from output.

    >>> parse_location_string("North America / USA / Washington / King County / Extra field", location_fields)
    {'region': 'North America', 'country': 'USA', 'division': 'Washington', 'location': 'King County'}

    Trailing location delimiters should be stripped from the output.

    >>> parse_location_string("North America / USA / Washington / King County / ", location_fields)
    {'region': 'North America', 'country': 'USA', 'division': 'Washington', 'location': 'King County'}

    Handle inconsistently delimited strings.

    >>> parse_location_string("North America/USA/New York/New York", location_fields)
    {'region': 'North America', 'country': 'USA', 'division': 'New York', 'location': 'New York'}
    >>> parse_location_string("Europe/ Lithuania", location_fields)
    {'region': 'Europe', 'country': 'Lithuania', 'division': '?', 'location': '?'}

    """
    # Try to extract values for specific geographic scales.
    values = re.split(r"[ ]*/[ ]*", location_string)

    # Create a default mapping of location fields to missing values and update
    # these from the values in the location string.
    locations = {field: "?" for field in location_fields}
    locations.update(dict(zip(location_fields, values)))

    return locations


def strip_prefixes(strain_name, prefixes):
    """Strip the given prefixes from the given strain name.

    Parameters
    ----------
    strain_name : str
        Name of a strain to be sanitized
    prefixes : list[str]
        A list of prefixes to be stripped from the strain name.

    Returns
    -------
    str :
        Strain name without any of the given prefixes.

    >>> strip_prefixes("hCoV-19/RandomStrain/1/2020", ["hCoV-19/", "SARS-CoV-2/"])
    'RandomStrain/1/2020'
    >>> strip_prefixes("SARS-CoV-2/RandomStrain/2/2020", ["hCoV-19/", "SARS-CoV-2/"])
    'RandomStrain/2/2020'
    >>> strip_prefixes("hCoV-19/RandomStrain/1/2020", ["SARS-CoV-2/"])
    'hCoV-19/RandomStrain/1/2020'

    """
    joined_prefixes = "|".join(prefixes)
    pattern = f"^({joined_prefixes})"
    return re.sub(pattern, "", strain_name)


def get_database_ids_by_strain(metadata_file, metadata_id_columns, database_id_columns, metadata_chunk_size, error_on_duplicates=False):
    """Get a mapping of all database ids for each strain name.

    Parameters
    ----------
    metadata_file : str or Path-like or file object
        Path or file object for a metadata file to process.
    metadata_id_columns : list[str]
        A list of potential id columns for strain names in the metadata.
    database_id_columns : list[str]
        A list of potential database id columns whose values can be used to deduplicate records with the same strain name.
    metadata_chunk_size : int
        Number of records to read into memory at once from the metadata.
    error_on_duplicates : bool
        Throw an error when duplicate records are detected.

    Returns
    -------
    str or Path-like or file object or None :
        Path or file object containing the mapping of database ids for each
        strain name (one row per combination). Returns None if no valid
        database ids were found and no duplicates exist.

    Raises
    ------
    DuplicateException :
        When duplicates are detected and the caller has requested an error on duplicates.
    MissingColumnException :
        When none of the requested metadata id columns exist.

    """
    try:
        metadata_reader = read_metadata(
            metadata_file,
            id_columns=metadata_id_columns,
            chunk_size=metadata_chunk_size,
        )
    except Exception as error:
        # Augur's `read_metadata` function can throw a generic Exception when
        # the input is missing id columns. This exception is not easily
        # distinguished from any other error, so we check the contents of the
        # error message and raise a more specific error for better handling of
        # unexpected errors.
        if "None of the possible id columns" in str(error):
            raise MissingColumnException(str(error)) from error
        else:
            raise

    # Track strains we have observed, so we can alert the caller to duplicate
    # strains when an error on duplicates has been requested.
    observed_strains = set()
    duplicate_strains = set()

    with NamedTemporaryFile(delete=False, mode="wt", encoding="utf-8", newline="") as mapping_file:
        mapping_path = mapping_file.name
        header = True

        for metadata in metadata_reader:
            # Sanitize strain names here so the names recorded in the mapping
            # file match the names produced in the second pass below. Note that
            # this relies on the module-level `args` defined in the command
            # line interface at the bottom of this script.
            metadata = sanitize_strain_names(metadata, args.strip_prefixes)

            # Check for database id columns.
            valid_database_id_columns = metadata.columns.intersection(
                database_id_columns
            )
            if mapping_path and len(valid_database_id_columns) == 0:
                # Do not write out mapping of ids. Default to error on
                # duplicates, since we have no way to resolve duplicates
                # automatically.
                mapping_path = None
                error_on_duplicates = True
                print(
                    "WARNING: Skipping deduplication of metadata records.",
                    f"None of the possible database id columns ({database_id_columns}) were found in the metadata's columns {tuple([metadata.index.name] + metadata.columns.values.tolist())}",
                    file=sys.stderr
                )

            # Track duplicates in memory, as needed.
            if error_on_duplicates:
                for strain in metadata.index.values:
                    if strain in observed_strains:
                        duplicate_strains.add(strain)
                    else:
                        observed_strains.add(strain)

            if mapping_path:
                # Write mapping of database and strain ids to disk.
                metadata.loc[:, valid_database_id_columns].to_csv(
                    mapping_file,
                    sep="\t",
                    header=header,
                    index=True,
                )
                header = False

    # Clean up temporary file.
    if mapping_path is None:
        os.unlink(mapping_file.name)

    if error_on_duplicates and len(duplicate_strains) > 0:
        duplicates_file = metadata_file + ".duplicates.txt"
        with open(duplicates_file, "w") as oh:
            for strain in duplicate_strains:
                oh.write(f"{strain}\n")

        raise DuplicateException(f"{len(duplicate_strains)} strains have duplicate records. See '{duplicates_file}' for more details.")

    return mapping_path
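
# The mapping file written by get_database_ids_by_strain is a tab-delimited
# table whose first column is the strain name (the metadata's id column) and
# whose remaining columns are the database id columns found in the metadata,
# with one row per strain/id combination. For example (hypothetical values,
# assuming "Accession ID" is the only database id column present):
#
#   strain               Accession ID
#   RandomStrain/1/2020  EPI_ISL_0000001
#   RandomStrain/2/2020  EPI_ISL_0000002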


def sanitize_strain_names(metadata, prefixes_to_strip):
    """Remove and replace certain characters in strain names.

    Parameters
    ----------
    metadata : pandas.DataFrame
        A data frame indexed by strain name.
    prefixes_to_strip : list[str]
        A list of prefixes to be stripped from the strain name.

    Returns
    -------
    pandas.DataFrame :
        A data frame with sanitized strain names in its index.

    """
    # Reset the data frame index, to make the "strain" column available
    # for transformation.
    strain_field = metadata.index.name
    metadata = metadata.reset_index()

    # Strip prefixes from strain names.
    if prefixes_to_strip:
        metadata[strain_field] = metadata[strain_field].apply(
            lambda strain: strip_prefixes(strain, prefixes_to_strip)
        )

    # Remove whitespace from strain names to match Nextstrain's convention,
    # since whitespace is not allowed in FASTA record names.
    metadata[strain_field] = metadata[strain_field].str.replace(" ", "")

    # Replace standard characters that are not accepted by all downstream
    # tools as valid FASTA names.
    metadata[strain_field] = metadata[strain_field].str.replace("'", "-")

    # Set the index back to the strain column.
    metadata = metadata.set_index(strain_field)

    return metadata
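
# Illustrative example of sanitize_strain_names (hypothetical strain name):
# with prefixes_to_strip=["hCoV-19/"], a record named
# "hCoV-19/New Zealand/d'Urville/2020" becomes "NewZealand/d-Urville/2020"
# after prefix stripping, whitespace removal, and apostrophe replacement.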


def filter_duplicates(metadata, database_ids_by_strain):
    """Filter duplicate records by the strain name in the given data frame index
    using the given file containing a mapping of strain names to database ids.

    Database ids allow us to identify duplicate records that need to be
    excluded. We prefer the latest record for a given strain name across all
    possible database ids and filter out all other records for that same strain
    name.

    Parameters
    ----------
    metadata : pandas.DataFrame
        A data frame indexed by strain name.
    database_ids_by_strain : str or Path-like or file object
        Path or file object containing the mapping of database ids for each strain name (one row per combination).

    Returns
    -------
    pandas.DataFrame :
        A filtered data frame with no duplicate records.

    """
    # Get strain names for the given metadata.
    strain_ids = set(metadata.index.values)

    # Get the mappings of database ids to strain names for the current strains.
    with open(database_ids_by_strain, "r", encoding="utf-8", newline="") as fh:
        reader = csv.DictReader(fh, delimiter="\t")

        # The mapping file stores the strain name in the first column. All other
        # fields are database ids.
        strain_field = reader.fieldnames[0]
        database_id_columns = reader.fieldnames[1:]

        # Keep only records matching the current strain ids.
        mappings = pd.DataFrame([
            row
            for row in reader
            if row[strain_field] in strain_ids
        ])

    # Check for duplicate strains in the given metadata or strains that do not
    # have any mappings. If there are none, return the metadata as it is. If
    # duplicates or strains without mappings exist, filter them out.
    if any(mappings.duplicated(strain_field)) or len(strain_ids) != mappings.shape[0]:
        # Create a list of database ids of records to keep. To this end, we sort by
        # database ids in descending order such that the latest record appears
        # first, then we take the first record for each strain name.
        records_to_keep = mappings.sort_values(
            database_id_columns,
            ascending=False
        ).groupby(strain_field).first()

        # Select metadata corresponding to database ids to keep. Database ids
        # may not be unique for different strains (e.g., "?"), so we need to
        # merge on strain name and database ids. Additionally, the same strain
        # may appear multiple times in the metadata with the same id. These
        # accidental duplicates will also produce a merge with records to keep
        # that is not a one-to-one merge. To handle this case, we need to drop
        # any remaining duplicate records by strain name. The order in which we
        # resolve these duplicates does not matter, since the fields we would
        # use to resolve them contain identical values.
        merge_columns = sorted(set([strain_field]) | set(database_id_columns))
        metadata = metadata.reset_index().merge(
            records_to_keep,
            on=merge_columns,
        ).drop_duplicates(subset=strain_field).set_index(strain_field)

    # Track strains that we've processed and drop these from the mappings file.
    # In this way, we can track strains that have been processed across multiple
    # chunks of metadata and avoid emitting duplicates that appear in different
    # chunks.
    with open(database_ids_by_strain, "r", encoding="utf-8", newline="") as fh:
        reader = csv.DictReader(fh, delimiter="\t")

        with NamedTemporaryFile(delete=False, mode="wt", encoding="utf-8", newline="") as new_mapping_file:
            new_mapping_path = new_mapping_file.name
            writer = csv.DictWriter(
                new_mapping_file,
                fieldnames=reader.fieldnames,
                delimiter="\t",
                lineterminator="\n",
            )
            writer.writeheader()

            for row in reader:
                if row[strain_field] not in strain_ids:
                    writer.writerow(row)

    # After writing out the new mapping of ids without strains we just
    # processed, copy the new mapping over the original file and delete the
    # temporary new mapping file.
    shutil.copyfile(
        new_mapping_path,
        database_ids_by_strain,
    )
    os.unlink(new_mapping_path)

    return metadata
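
# Illustrative example of the deduplication rule above (hypothetical values):
# if "RandomStrain/1/2020" appears twice in the mapping file with accessions
# "EPI_ISL_0000001" and "EPI_ISL_0000002", sorting the ids in descending order
# and keeping the first record per strain retains only the row with
# "EPI_ISL_0000002", i.e. the lexicographically greatest (typically most
# recent) accession.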


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        usage="Sanitize metadata from different sources, applying operations (deduplicate, parse location field, strip prefixes, and rename fields) in the same order they appear in the full help (-h).",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--metadata", required=True, help="metadata to be sanitized")
    parser.add_argument("--metadata-id-columns", default=["strain", "name", "Virus name"], nargs="+", help="names of valid metadata columns containing identifier information like 'strain' or 'name'")
    parser.add_argument("--database-id-columns", default=["Accession ID", "gisaid_epi_isl", "genbank_accession"], nargs="+", help="names of metadata columns that store external database ids for each record (e.g., GISAID, GenBank, etc.) that can be used to deduplicate metadata records with the same strain names.")
    parser.add_argument("--metadata-chunk-size", type=int, default=100000, help="maximum number of metadata records to read into memory at a time. Increasing this number can speed up filtering at the cost of more memory used.")
    parser.add_argument("--error-on-duplicate-strains", action="store_true", help="exit with an error if any duplicate strains are found. By default, duplicates are resolved by preferring the most recent accession id or the last record.")
    parser.add_argument("--parse-location-field", help="split the given GISAID location field on '/' and create new columns for region, country, etc. based on available data. Replaces missing geographic data with '?' values.")
    parser.add_argument("--strip-prefixes", nargs="+", help="prefixes to strip from strain names in the metadata")
    parser.add_argument("--rename-fields", nargs="+", help="rename specific fields from the string on the left of the equal sign to the string on the right (e.g., 'Virus name=strain')")
    parser.add_argument("--output", required=True, help="sanitized metadata")

    args = parser.parse_args()
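
    # Example invocation (illustrative; file names are hypothetical and the
    # "Location" column name assumes a GISAID-style metadata download):
    #
    #   python3 sanitize_metadata.py \
    #       --metadata data/gisaid_metadata.tsv.gz \
    #       --parse-location-field Location \
    #       --rename-fields "Virus name=strain" "Accession ID=gisaid_epi_isl" \
    #       --strip-prefixes "hCoV-19/" \
    #       --output data/sanitized_metadata.tsv.gz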

    # Get user-defined metadata id columns to look for.
    metadata_id_columns = args.metadata_id_columns

    # Get user-defined database id columns to use for deduplication.
    database_id_columns = args.database_id_columns

    # If the input is a tarball, try to find a metadata file inside the archive.
    metadata_file = args.metadata
    metadata_is_temporary = False
    if ".tar" in Path(args.metadata).suffixes:
        try:
            temporary_dir, metadata_file = extract_tar_file_contents(
                args.metadata,
                "metadata"
            )
            metadata_is_temporary = True
        except FileNotFoundError as error:
            print(f"ERROR: {error}", file=sys.stderr)
            sys.exit(1)

    # In the first pass through the metadata, map strain names to database ids.
    # We will use this mapping to deduplicate records in the second pass.
    # Additionally, this pass checks for missing id columns and the presence of
    # any duplicate records, in case the user has requested an error on
    # duplicates.
    try:
        database_ids_by_strain = get_database_ids_by_strain(
            metadata_file,
            metadata_id_columns,
            database_id_columns,
            args.metadata_chunk_size,
            args.error_on_duplicate_strains,
        )
    except (DuplicateException, MissingColumnException) as error:
        print(f"ERROR: {error}", file=sys.stderr)
        sys.exit(1)

    # Parse mapping of old column names to new.
    rename_fields = args.rename_fields if args.rename_fields else []
    new_column_names = parse_new_column_names(rename_fields)

    # In the second pass through the metadata, filter duplicate records,
    # transform records with requested sanitizer steps, and stream the output to
    # disk.
    metadata_reader = read_metadata(
        metadata_file,
        id_columns=metadata_id_columns,
        chunk_size=args.metadata_chunk_size,
    )
    emit_header = True
    with open_file(args.output, "w") as output_file_handle:
        for metadata in metadata_reader:
            metadata = sanitize_strain_names(metadata, args.strip_prefixes)

            if database_ids_by_strain:
                # Filter duplicates. This should only happen after all
                # transformations to the strain column.
                metadata = filter_duplicates(
                    metadata,
                    database_ids_by_strain,
                )

            # Parse GISAID location field into separate fields for geographic
            # scales. Replace missing field values with "?".
            if args.parse_location_field and args.parse_location_field in metadata.columns:
                locations = pd.DataFrame.from_dict(
                    {
                        strain: parse_location_string(location, LOCATION_FIELDS)
                        for strain, location in metadata[args.parse_location_field].items()
                    }, orient='index'
                )
                locations.index.name = metadata.index.name

                # Combine new location columns with original metadata and drop the
                # original location column.
                metadata = pd.concat(
                    [
                        metadata,
                        locations
                    ],
                    axis=1
                ).drop(columns=[args.parse_location_field])

            # Rename columns as needed, after transforming strain names. This
            # allows us to avoid keeping track of a new strain name field
            # provided by the user.
            if len(new_column_names) > 0:
                metadata = metadata.rename(columns=new_column_names)

                if metadata.index.name in new_column_names:
                    metadata.index = metadata.index.rename(new_column_names[metadata.index.name])

            # Write filtered and transformed metadata to the output file.
            metadata.to_csv(
                output_file_handle,
                sep="\t",
                index=True,
                header=emit_header,
            )
            emit_header = False

    if database_ids_by_strain:
        # Delete the database/strain id mapping.
        os.unlink(database_ids_by_strain)

    # Clean up temporary directory and files that came from a tarball.
    if metadata_is_temporary:
        print(f"Cleaning up temporary files in {temporary_dir.name}", file=sys.stderr)
        temporary_dir.cleanup()