-
Notifications
You must be signed in to change notification settings - Fork 0
/
Utility_functions.py
42 lines (30 loc) · 1.08 KB
/
Utility_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from transformers import AutoProcessor, Wav2Vec2Model
import numpy as np
import torch
def normalize(x, epsilon=1e-8):
    """
    Min-max normalise the audio samples in ``x`` to (approximately) [0, 1].

    The minimum and maximum are taken over *all* elements of ``x`` (not per
    row), so a 2-D batch of clips is scaled jointly.

    Parameters
    ----------
    x : numpy.ndarray or torch.Tensor
        Audio samples. Integer or boolean inputs (e.g. int16 PCM) are first
        converted to floating point, because computing ``x - x.min()`` in a
        narrow integer dtype silently wraps around.
    epsilon : float, optional
        Added to the denominator to avoid dividing by zero when ``x`` is
        constant (every output value is then 0).

    Returns
    -------
    numpy.ndarray or torch.Tensor
        Same container type as ``x`` with floating-point values; the minimum
        maps to 0 and the maximum to just under 1.
    """
    if isinstance(x, np.ndarray):
        if x.dtype.kind in "biu":
            # bool / signed / unsigned ints: avoid wrap-around in the subtraction.
            x = x.astype(np.float64)
    elif torch.is_tensor(x):
        if not (torch.is_floating_point(x) or torch.is_complex(x)):
            x = x.to(torch.get_default_dtype())
    maxim, mini = x.max(), x.min()
    x_normed = (x - mini) / (
        (maxim - mini) + epsilon
    )  # epsilon to avoid dividing by zero
    return x_normed
def exctract_wav2vec(
    audios,
    sampling_rate=16000,
    *,
    model_name="facebook/wav2vec2-base-960h",
    processor=None,
    model=None,
):
    """
    Wrapper for the wav2vec 2.0 model, outputs only the hidden state (features).

    Each row of ``audios`` is passed through the processor and model on its
    own (batch size 1), and the final-layer hidden states are collected.

    Parameters
    ----------
    audios : array-like
        Indexed as ``audios[i, :]``, so it must be at least 2-D; row ``i`` is
        treated as one clip. Presumably a numpy array of shape
        (num_clips, num_samples) -- TODO confirm against callers.
    sampling_rate : int, optional
        Sampling rate of ``audios``, forwarded to the processor
        (default 16000).
    model_name : str, optional
        Hugging Face checkpoint to load when ``processor`` / ``model`` are not
        supplied.
    processor, model : optional
        Pre-loaded processor and ``Wav2Vec2Model``. Passing them lets callers
        reuse one loaded checkpoint instead of re-loading it on every call.

    Returns
    -------
    list of numpy.ndarray
        One array per row of ``audios``: the model's ``last_hidden_state``
        for that single clip (leading batch dimension of size 1).
    """
    if processor is None:
        processor = AutoProcessor.from_pretrained(model_name)
    if model is None:
        model = Wav2Vec2Model.from_pretrained(model_name)

    features = []
    num_clips = audios.shape[0]
    for i in range(num_clips):
        # '\r' rewrites the same console line; 1-based so the last update
        # reads "sample N / N".
        print("sample {} / {}".format(i + 1, num_clips), end="\r")
        inputs = processor(
            audios[i, :], sampling_rate=sampling_rate, return_tensors="pt"
        )  # Batch size 1
        # Inputs must live on the same device as the model (a caller may
        # pass in a model that is not on the CPU).
        inputs = inputs.to(model.device)
        with torch.no_grad():
            outputs = model(**inputs)
        # .cpu() is a no-op for CPU tensors and required before .numpy()
        # otherwise; no .detach() needed since no_grad disables tracking.
        features.append(outputs.last_hidden_state.cpu().numpy())
    return features