-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Expand file tree
/
Copy paths3prl.py
More file actions
117 lines (97 loc) · 4.19 KB
/
s3prl.py
File metadata and controls
117 lines (97 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import copy
import logging
from typing import Optional, Tuple, Union
import humanfriendly
import torch
from typeguard import typechecked
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.legacy.nets.pytorch_backend.frontends.frontend import Frontend
from espnet2.utils.get_default_kwargs import get_default_kwargs
class S3prlFrontend(AbsFrontend):
    """Speech Pretrained Representation frontend structure for ASR.

    Wraps an S3PRL upstream model (e.g. wav2vec2, HuBERT) behind the ESPnet
    ``AbsFrontend`` interface and converts raw waveforms into frame-level
    features, either from a single upstream layer or as a learned weighted
    sum over all layers (via ``s3prl.nn.Featurizer``).
    """

    @typechecked
    def __init__(
        self,
        fs: Union[int, str] = 16000,
        frontend_conf: Optional[dict] = None,
        download_dir: Optional[str] = None,
        multilayer_feature: bool = False,
        layer: int = -1,
    ):
        """Initialize the S3PRL upstream and its featurizer.

        Args:
            fs: Sampling rate in Hz, or a human-friendly string such as
                ``"16k"``. S3PRL upstreams only support 16 kHz; any other
                rate emits a warning.
            frontend_conf: S3PRL configuration dict. Must contain
                ``"upstream"`` (an ``S3PRLUpstream.available_names()`` entry);
                optional keys are ``"path_or_url"``, ``"normalize"``,
                ``"extra_conf"`` and ``"tile_factor"``. When ``None``, a fresh
                ``get_default_kwargs(Frontend)`` dict is used.
            download_dir: Directory S3PRL uses to cache downloaded models.
            multilayer_feature: Feed all upstream layers to the featurizer
                (weighted sum) instead of only the last layer.
            layer: Select exactly one upstream layer; ``-1`` means
                "unspecified". Mutually exclusive with ``multilayer_feature``.

        Raises:
            Exception: re-raised import error when S3PRL is not installed.
            AssertionError: unknown upstream name, or ``layer`` combined with
                ``multilayer_feature=True``.
        """
        try:
            import s3prl
            from s3prl.nn import Featurizer, S3PRLUpstream
        except Exception as e:
            print("Error: S3PRL is not properly installed.")
            print("Please install S3PRL: cd ${MAIN_ROOT}/tools && make s3prl.done")
            raise e

        super().__init__()

        # Bug fix: the previous default `get_default_kwargs(Frontend)` in the
        # signature was a mutable dict built once at class-definition time and
        # shared by every instance; explicitly passing None (allowed by the
        # Optional[dict] annotation) also crashed on `frontend_conf.get`.
        # Build the default per instance instead.
        if frontend_conf is None:
            frontend_conf = get_default_kwargs(Frontend)

        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)
        if fs != 16000:
            logging.warning(
                "All the upstream models in S3PRL now only support 16 kHz audio."
            )

        if download_dir is not None:
            s3prl.util.download.set_dir(download_dir)

        assert frontend_conf.get("upstream", None) in S3PRLUpstream.available_names()
        upstream = S3PRLUpstream(
            frontend_conf.get("upstream"),
            path_or_url=frontend_conf.get("path_or_url", None),
            normalize=frontend_conf.get("normalize", False),
            extra_conf=frontend_conf.get("extra_conf", None),
        )

        # Re-enable full gradient flow through the CNN feature extractor when
        # the wrapped model exposes `feature_grad_mult` (e.g. wav2vec2/HuBERT,
        # which scale feature-extractor gradients down by default).
        if getattr(upstream.upstream, "model", None):
            if getattr(upstream.upstream.model, "feature_grad_mult", None) is not None:
                upstream.upstream.model.feature_grad_mult = 1.0
        upstream.eval()

        if layer != -1:
            # A specific layer was requested: restrict the featurizer to it
            # and forbid the (contradictory) multilayer weighted sum.
            layer_selections = [layer]
            assert (
                not multilayer_feature
            ), "multilayer feature will be deactivated, when specific layer used"
        else:
            layer_selections = None
        featurizer = Featurizer(upstream, layer_selections=layer_selections)

        self.multilayer_feature = multilayer_feature
        self.layer = layer
        self.upstream, self.featurizer = upstream, featurizer
        # Snapshot of the pretrained weights so they can be restored later by
        # reload_pretrained_parameters() (e.g. after a warm-up stage).
        self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
        self.frontend_type = "s3prl"
        self.hop_length = self.featurizer.downsample_rate
        self.tile_factor = frontend_conf.get("tile_factor", 1)

    def _tile_representations(self, feature: torch.Tensor) -> torch.Tensor:
        """Tile up the representations by `tile_factor`.

        Input - sequence of representations
            shape: (batch_size, seq_len, feature_dim)
        Output - sequence of tiled representations
            shape: (batch_size, seq_len * factor, feature_dim)
        """
        assert (
            len(feature.shape) == 3
        ), "Input argument `feature` has invalid shape: {}".format(feature.shape)
        # repeat along the feature dim, then fold the copies into the time dim
        tiled_feature = feature.repeat(1, 1, self.tile_factor)
        tiled_feature = tiled_feature.reshape(
            feature.size(0), feature.size(1) * self.tile_factor, feature.size(2)
        )
        return tiled_feature

    def output_size(self) -> int:
        """Return the feature dimension produced by the featurizer."""
        return self.featurizer.output_size

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract frame-level features from raw waveforms.

        Args:
            input: Waveform batch, presumably (batch, samples) — shape is
                dictated by the S3PRL upstream.
            input_lengths: Valid length of each waveform in samples.

        Returns:
            Tuple of (features, feature_lengths).
        """
        feats, feats_lens = self.upstream(input, input_lengths)
        if self.layer != -1:
            # NOTE(review): tile_factor is not applied on this early-return
            # path — confirm whether that is intentional.
            layer = self.layer
            feats, feats_lens = feats[layer], feats_lens[layer]
            return feats, feats_lens

        if self.multilayer_feature:
            feats, feats_lens = self.featurizer(feats, feats_lens)
        else:
            # Last layer only; keep list shape with a 1-element slice.
            feats, feats_lens = self.featurizer(feats[-1:], feats_lens[-1:])

        if self.tile_factor != 1:
            feats = self._tile_representations(feats)

        return feats, feats_lens

    def reload_pretrained_parameters(self):
        """Restore the upstream weights snapshotted at construction time."""
        self.upstream.load_state_dict(self.pretrained_params)
        logging.info("Pretrained S3PRL frontend model parameters reloaded!")