- Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
190 lines (167 loc) · 6.95 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
importargparse
frompathlibimportPath
importsys
importtorchaudio
fromcompressimportcompress, decompress, MODELS
fromutilsimportsave_audio, convert_audio
# 定义文件后缀
SUFFIX='.ecdc'
defget_parser():
"""
创建命令行参数解析器。
返回:
argparse.ArgumentParser: 配置好的命令行参数解析器。
"""
parser=argparse.ArgumentParser(
'encodec',
description='High fidelity neural audio codec. '
'If input is a .ecdc, decompresses it. '
'If input is .wav, compresses it. If output is also wav, '
'do a compression/decompression cycle.')
# 添加输入文件参数
parser.add_argument(
'input', type=Path,
help='Input file, whatever is supported by torchaudio on your system.')
# 添加输出文件参数(可选)
parser.add_argument(
'output', type=Path, nargs='?',
help='Output file, otherwise inferred from input file.')
# 添加带宽参数
parser.add_argument(
'-b', '--bandwidth', type=float, default=6, choices=[1.5, 3., 6., 12., 24.],
help='Target bandwidth (1.5, 3, 6, 12 or 24). 1.5 is not supported with --hq.')
# 添加高质量模式参数
parser.add_argument(
'-q', '--hq', action='store_true',
help='Use HQ stereo model operating on 48 kHz sampled audio.')
# 添加语言模型参数
parser.add_argument(
'-l', '--lm', action='store_true',
help='Use a language model to reduce the model size (5x slower though).')
# 添加覆盖输出文件参数
parser.add_argument(
'-f', '--force', action='store_true',
help='Overwrite output file if it exists.')
# 添加解压缩后缀参数
parser.add_argument(
'-s', '--decompress_suffix', type=str, default='_decompressed',
help='Suffix for the decompressed output file (if no output path specified)')
# 添加自动缩放参数
parser.add_argument(
'-r', '--rescale', action='store_true',
help='Automatically rescale the output to avoid clipping.')
returnparser
deffatal(*args):
"""
打印错误信息到标准错误并退出程序。
Args:
*args: 可变数量的位置参数,打印为错误信息。
"""
print(*args, file=sys.stderr)
sys.exit(1)
defcheck_output_exists(args):
"""
检查输出路径是否存在。如果输出目录不存在,则终止程序。
如果输出文件已存在且未使用 -f / --force 参数,则终止程序。
Args:
args: 解析后的命令行参数。
"""
ifnotargs.output.parent.exists():
fatal(f"Output folder for {args.output} does not exist.")
ifargs.output.exists() andnotargs.force:
fatal(f"Output file {args.output} exist. Use -f / --force to overwrite.")
defcheck_clipping(wav, args):
"""
检查音频是否发生削波。如果音频的最大绝对值超过0.99,则发出警告。
如果使用了 -r / --rescale 参数,则不进行此检查。
Args:
wav (torch.Tensor): 要检查的音频张量。
args: 解析后的命令行参数。
"""
ifargs.rescale:
return
mx=wav.abs().max()
limit=0.99
ifmx>limit:
print(
f"Clipping!! max scale {mx}, limit is {limit}. "
"To avoid clipping, use the `-r` option to rescale the output.",
file=sys.stderr)
defmain():
"""
主函数,执行压缩或解压缩操作。
流程:
1. 解析命令行参数。
2. 检查输入文件是否存在。
3. 根据输入文件的后缀决定执行压缩还是解压缩。
4. 如果是解压缩:
- 如果未指定输出文件,则生成默认的输出文件名。
- 检查输出路径是否存在。
- 解压缩输入文件。
- 检查是否发生削波。
- 保存解压后的音频。
5. 如果是压缩:
- 如果未指定输出文件,则生成默认的输出文件名。
- 检查输出路径是否存在。
- 加载并转换音频文件。
- 压缩音频。
- 如果输出文件后缀为 .ecdc,则直接保存压缩后的数据。
- 如果输出文件后缀为 .wav,则解压缩并保存音频。
"""
# 解析命令行参数
args=get_parser().parse_args()
# 检查输入文件是否存在
ifnotargs.input.exists():
fatal(f"Input file {args.input} does not exist.")
# 判断输入文件的后缀是否为 .ecdc,如果是,则执行解压缩
ifargs.input.suffix.lower() ==SUFFIX:
# Decompression
ifargs.outputisNone:
# 如果未指定输出文件,则生成默认的输出文件名
args.output=args.input.with_name(args.input.stem+args.decompress_suffix).with_suffix('.wav')
elifargs.output.suffix.lower() !='.wav':
fatal("Output extension must be .wav")
# 检查输出路径是否存在
check_output_exists(args)
# 解压缩输入文件
out, out_sample_rate=decompress(args.input.read_bytes())
# 检查是否发生削波
check_clipping(out, args)
# 保存解压后的音频
save_audio(out, args.output, out_sample_rate, rescale=args.rescale)
else:
# Compression
ifargs.outputisNone:
# 如果未指定输出文件,则生成默认的输出文件名
args.output=args.input.with_suffix(SUFFIX)
elifargs.output.suffix.lower() notin [SUFFIX, '.wav']:
# 如果指定了输出文件但后缀不是 .wav 或 .ecdc,则终止程序
fatal(f"Output extension must be .wav or {SUFFIX}")
# 检查输出路径是否存在
check_output_exists(args)
# 选择模型
model_name='encodec_48khz'ifargs.hqelse'encodec_24khz'
# 加载模型
model=MODELS[model_name]()
# 检查带宽是否被模型支持
ifargs.bandwidthnotinmodel.target_bandwidths:
fatal(f"Bandwidth {args.bandwidth} is not supported by the model {model_name}")
# 设置模型的目标带宽
model.set_target_bandwidth(args.bandwidth)
# 加载并转换音频
wav, sr=torchaudio.load(args.input)
wav=convert_audio(wav, sr, model.sample_rate, model.channels)
# 压缩音频
compressed=compress(model, wav, use_lm=args.lm)
# 根据输出文件的后缀决定保存方式
ifargs.output.suffix.lower() ==SUFFIX:
args.output.write_bytes(compressed)
else:
# Directly run decompression stage
# 如果输出文件后缀为 .wav,则解压缩并保存音频
assertargs.output.suffix.lower() =='.wav'
out, out_sample_rate=decompress(compressed)
check_clipping(out, args)
save_audio(out, args.output, out_sample_rate, rescale=args.rescale)
if__name__=='__main__':
main()