- Notifications
You must be signed in to change notification settings - Fork 406
/
Copy pathget_distance_to_focal_set.py
198 lines (155 loc) · 7.63 KB
/
get_distance_to_focal_set.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""
Calculate minimal distances between sequences in an alignment and a set of focal sequences
"""
importargparse
fromaugur.ioimportread_sequences
fromrandomimportshuffle
fromcollectionsimportdefaultdict
importBio
importnumpyasnp
fromBio.SeqIO.FastaIOimportSimpleFastaParser
fromBio.SeqimportSeq
fromBioimportAlignIO, SeqIO
fromscipyimportsparse
importsys
def compactify_sequences(sparse_matrix, sequence_names):
    """Group sequence names whose rows share an identical pattern of stored values.

    Each row of ``sparse_matrix`` is paired with the name at the same position
    in ``sequence_names``. A row's nonzero entries are turned into a tuple of
    ``(column, value)`` pairs, and that tuple is the grouping key. Rows with no
    nonzero entry are collected under the empty tuple.

    Returns a ``defaultdict(list)`` that maps each key to the names (in input
    order) whose rows produced it.
    """
    groups_by_pattern = defaultdict(list)
    for name, row in zip(sequence_names, sparse_matrix):
        nonzero_index = row.nonzero()
        # Fetched unconditionally, exactly like the values used for non-empty rows.
        row_values = np.array(row[nonzero_index])
        columns = nonzero_index[1]
        pattern = tuple(zip(columns, row_values[0])) if len(columns) else ()
        groups_by_pattern[pattern].append(name)
    return groups_by_pattern
# Initial capacity, and growth increment, of the coordinate buffers that
# calculate_snp_matrix fills while building its sparse matrix.
INITIALISATION_LENGTH = 1_000_000


def sequence_to_int_array(s, fill_value=110, fill_gaps=True):
    """Encode a sequence as an int8 array of lower-case ASCII codes.

    ``s`` is converted with ``str()``, lower-cased and UTF-8 encoded. Every
    code other than a/c/g/t (97/99/103/116) is overwritten with
    ``fill_value`` (default 110, the code for "n"). When ``fill_gaps`` is
    False, the gap character "-" (45) is also kept rather than overwritten.

    Returns a writable copy, so callers may modify it freely.
    """
    kept_codes = [ord(base) for base in "acgt"]
    if not fill_gaps:
        kept_codes.append(ord("-"))
    codes = np.frombuffer(str(s).lower().encode("utf-8"), dtype=np.int8).copy()
    codes[~np.isin(codes, kept_codes)] = fill_value
    return codes
# Function adapted from https://github.com/gtonkinhill/pairsnp-python
def calculate_snp_matrix(fastafile, consensus=None, zipped=False, fill_value=110, chunk_size=0, ignore_seqs=None):
    """Build a sparse matrix of differences between sequences and a consensus.

    Each accepted record is encoded with ``sequence_to_int_array``. At every
    position where the encoded record differs from ``consensus`` and is not
    ``fill_value``, the record's character code is stored in the matrix.

    Parameters
    ----------
    fastafile : iterable
        Records exposing ``.name`` and ``.seq``. Consumed lazily, so when
        ``chunk_size`` is set the same iterator can be passed repeatedly to
        process it chunk by chunk.
    consensus : numpy.ndarray or None
        Encoded consensus. If None, the first accepted record is encoded and
        used as the consensus.
    zipped : bool
        Not used by this function; kept for interface compatibility.
    fill_value : int
        Code for masked positions. These are not stored in the matrix; their
        indices are returned in ``filled_positions`` instead.
    chunk_size : int
        If non-zero, stop after this many accepted records.
    ignore_seqs : iterable of str or None
        Record names to skip.

    Returns
    -------
    dict or None
        None when no record was accepted. Otherwise a dict with these keys:
        ``'snps'`` (``scipy.sparse.csc_matrix`` of shape
        ``(n_records, alignment_length)``), ``'consensus'``, ``'names'``, and
        ``'filled_positions'`` (one index array per record).

    Raises
    ------
    ValueError
        If a record's length differs from the alignment length.
    """
    # A set makes each per-record membership test O(1).
    ignore_seqs = set() if ignore_seqs is None else set(ignore_seqs)

    # COO-style coordinate buffers. Row indices are integers, so store them as such.
    row = np.empty(INITIALISATION_LENGTH, dtype=np.int64)
    col = np.empty(INITIALISATION_LENGTH, dtype=np.int64)
    val = np.empty(INITIALISATION_LENGTH, dtype=np.int8)
    current_length = INITIALISATION_LENGTH

    n_snps = 0
    seq_names = []
    filled_positions = []

    for record in fastafile:
        name = record.name
        if name in ignore_seqs:
            continue
        seq_str = str(record.seq)
        if consensus is None:
            align_length = len(seq_str)
            # Take consensus as first sequence
            consensus = sequence_to_int_array(seq_str, fill_value=fill_value)
        else:
            align_length = len(consensus)

        if len(seq_str) != align_length:
            raise ValueError('Fasta file appears to have sequences of different lengths!')

        seq_index = len(seq_names)
        seq_names.append(name)
        encoded = sequence_to_int_array(seq_str, fill_value=fill_value)
        # Positions that differ from the consensus and are not masked.
        snps = (consensus != encoded) & (encoded != fill_value)
        right = n_snps + int(np.count_nonzero(snps))
        filled_positions.append(np.flatnonzero(encoded == fill_value))

        if right >= current_length / 2:
            # A single record may add more than INITIALISATION_LENGTH entries
            # (very long alignments), so grow until there is headroom again
            # rather than by a single step.
            while right >= current_length / 2:
                current_length += INITIALISATION_LENGTH
            row.resize(current_length)
            col.resize(current_length)
            val.resize(current_length)

        row[n_snps:right] = seq_index
        col[n_snps:right] = np.flatnonzero(snps)
        val[n_snps:right] = encoded[snps]
        n_snps = right

        if chunk_size and chunk_size == len(seq_names):
            break

    if not seq_names:
        return None

    sparse_snps = sparse.csc_matrix(
        (val[:n_snps], (row[:n_snps], col[:n_snps])),
        shape=(len(seq_names), align_length),
    )
    return {'snps': sparse_snps, 'consensus': consensus, 'names': seq_names, 'filled_positions': filled_positions}
# Function adapted from https://github.com/gtonkinhill/pairsnp-python
def calculate_distance_matrix(sparse_matrix_A, sparse_matrix_B, consensus):
    """Return pairwise distances between the rows of two sparse code matrices.

    For row ``i`` of ``sparse_matrix_A`` and row ``j`` of ``sparse_matrix_B``,
    entry ``[i, j]`` is::

        pos(A_i) + pos(B_j) - both_pos - same_base - (n(A_i) + n(B_j)) + both_n

    The terms are:

    * ``pos`` counts a row's positive entries.
    * ``both_pos`` counts columns positive in both rows.
    * ``same_base`` counts columns where both rows hold the same a/c/g/t code
      (97/99/103/116).
    * ``n`` counts a row's entries equal to 110 ("n").
    * ``both_n`` counts columns where both rows hold 110.

    Matrices built by ``calculate_snp_matrix`` never store 110, so for them
    the last two terms vanish. The result is then the number of columns whose
    stored values differ, where a column stored in only one row counts as a
    difference.

    ``consensus`` is accepted for interface compatibility but is not used.

    Returns a ``numpy.matrix`` of shape ``(A.shape[0], B.shape[0])``.
    """
    n_rows_a = sparse_matrix_A.shape[0]
    n_rows_b = sparse_matrix_B.shape[0]
    b_transposed = sparse_matrix_B.transpose()

    # Columns where both rows carry the same base, accumulated over a, c, g, t.
    same_base = None
    for base_code in (97, 99, 103, 116):
        matches = (1 * (sparse_matrix_A == base_code)) * (b_transposed == base_code)
        same_base = matches if same_base is None else same_base + matches
    same_base = same_base.todense()

    # Columns where both rows hold the "n" code.
    both_masked = ((1 * (sparse_matrix_A == 110)) * (sparse_matrix_B == 110).transpose()).todense()
    same_base = same_base + both_masked

    # Positive entries per row, broadcast over every (i, j) pair.
    stored_total = np.zeros((n_rows_a, n_rows_b))
    stored_total[:] = (1 * (sparse_matrix_A > 0)).sum(1)
    stored_total += (1 * (sparse_matrix_B > 0)).sum(1).transpose()

    # Columns that are positive in both rows, whatever the codes.
    stored_in_both = (1 * (sparse_matrix_A > 0)) * (b_transposed > 0)

    masked_total = np.zeros((n_rows_a, n_rows_b))
    masked_total[:] = (1 * (sparse_matrix_A == 110)).sum(1)
    masked_total += (1 * (sparse_matrix_B == 110)).sum(1).transpose()
    masked_in_one = masked_total - 2 * both_masked

    return stored_total - stored_in_both.todense() - same_base - masked_in_one
if __name__ == '__main__':
    # Command-line entry point. For every sequence in --alignment, write the
    # closest sequence in --focal-alignment and its distance as a TSV row.
    parser = argparse.ArgumentParser(
        description="generate priorities files based on genetic proximity to focal sample",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--alignment", type=str, required=True, help="FASTA file of alignment")
    parser.add_argument("--reference", type=str, required=True, help="reference sequence (FASTA)")
    parser.add_argument("--ignore-seqs", type=str, nargs='+', help="sequences to ignore in distance calculation")
    parser.add_argument("--focal-alignment", type=str, required=True, help="focal sample of sequences")
    parser.add_argument("--chunk-size", type=int, default=10000, help="number of samples in the global alignment to process at once. Reduce this number to reduce memory usage at the cost of increased run-time.")
    parser.add_argument("--output", type=str, required=True, help="TSV file listing, for each sequence in the alignment, its closest focal strain and the distance to it")
    args = parser.parse_args()

    # Encode the reference; it is the shared consensus for both sequence sets.
    ref = sequence_to_int_array(SeqIO.read(args.reference, 'fasta').seq)
    alignment_length = len(ref)
    focal_seqs = read_sequences(args.focal_alignment)
    focal_seqs_dict = calculate_snp_matrix(focal_seqs, consensus=ref, ignore_seqs=args.ignore_seqs)

    if focal_seqs_dict is None:
        print(
            f"ERROR: There are no valid sequences in the focal alignment, '{args.focal_alignment}', to compare against the full alignment.",
            "Check your subsampling settings for the focal alignment or consider disabling proximity-based subsampling.",
            file=sys.stderr
        )
        sys.exit(1)

    # Tie-breaker added to the distances before argmin. Each value is the focal
    # sequence's masked-site count divided by the alignment length, so it lies
    # in [0, 1]. Among equally distant focal sequences, the one with fewer
    # masked sites wins. It does not depend on the chunk, so compute it once.
    mask_count_focal = np.array([len(x) for x in focal_seqs_dict['filled_positions']])
    focal_mask_penalty = mask_count_focal / alignment_length

    seqs = read_sequences(args.alignment)
    chunk_size = args.chunk_size
    chunk_count = 0

    # export priorities; the context manager closes the file even on error
    with open(args.output, 'w') as fh_out:
        fh_out.write('strain\tclosest strain\tdistance\n')
        while True:
            # Pull the next chunk of records from the shared iterator.
            context_seqs_dict = calculate_snp_matrix(seqs, consensus=ref, chunk_size=chunk_size)
            if context_seqs_dict is None:
                break
            print("Reading the alignments.", chunk_count*chunk_size)

            d = np.array(calculate_distance_matrix(context_seqs_dict['snps'], focal_seqs_dict['snps'], consensus=context_seqs_dict['consensus']))
            closest_match = np.argmin(d + focal_mask_penalty, axis=1)
            print("Done finding closest matches.")

            # Keyed by name: within one chunk, a repeated name keeps only its last row.
            minimal_distance_to_focal_set = {}
            for context_index, focal_index in enumerate(closest_match):
                minimal_distance_to_focal_set[context_seqs_dict['names'][context_index]] = (d[context_index, focal_index], focal_seqs_dict["names"][focal_index])
            for seqid, (distance, closest_strain) in minimal_distance_to_focal_set.items():
                fh_out.write(f"{seqid}\t{closest_strain}\t{distance}\n")
            chunk_count += 1