Skip to content

Commit 48dcedb

Browse files
committed
2 parents 5f77a4d + cc06786 commit 48dcedb

File tree

3 files changed

+251
-3
lines changed

3 files changed

+251
-3
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@ This repository contains code for the following Keras models:
66
- VGG19
77
- ResNet50
88
- Inception v3
9+
- CRNN for music tagging
910

1011
All architectures are compatible with both TensorFlow and Theano, and upon instantiation the models will be built according to the image dimension ordering set in your Keras configuration file at `~/.keras/keras.json`. For instance, if you have set `image_dim_ordering=tf`, then any model loaded from this repository will get built according to the TensorFlow dimension ordering convention, "Height-Width-Depth".
1112

12-
Weights can be automatically loaded upon instantiation (`weights='imagenet'` argument in model constructor). Weights are automatically downloaded if necessary, and cached locally in `~/.keras/models/`.
13-
14-
**Note that using these models requires the latest version of Keras (from the Github repo, not PyPI).**
13+
Pre-trained weights can be automatically loaded upon instantiation (`weights='imagenet'` argument in model constructor for all image models, `weights='msd'` for the music tagging model). Weights are automatically downloaded if necessary, and cached locally in `~/.keras/models/`.
1514

1615
## Examples
1716

@@ -78,6 +77,7 @@ block4_pool_features = model.predict(x)
7877
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556) - please cite this paper if you use the VGG models in your work.
7978
- [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) - please cite this paper if you use the ResNet model in your work.
8079
- [Rethinking the Inception Architecture for Computer Vision](http://arxiv.org/abs/1512.00567) - please cite this paper if you use the Inception v3 model in your work.
80+
- [Music-auto_tagging-keras](https://github.com/keunwoochoi/music-auto_tagging-keras)
8181

8282
Additionally, don't forget to [cite Keras](https://keras.io/getting-started/faq/#how-should-i-cite-keras) if you use these models.
8383

audio_conv_utils.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import numpy as np
2+
from keras import backend as K
3+
4+
5+
# The 50 tag labels, listed in the order in which `decode_predictions`
# pairs them with the 50 scores of each prediction row.
TAGS = [
    'rock',
    'pop',
    'alternative',
    'indie',
    'electronic',
    'female vocalists',
    'dance',
    '00s',
    'alternative rock',
    'jazz',
    'beautiful',
    'metal',
    'chillout',
    'male vocalists',
    'classic rock',
    'soul',
    'indie rock',
    'Mellow',
    'electronica',
    '80s',
    'folk',
    '90s',
    'chill',
    'instrumental',
    'punk',
    'oldies',
    'blues',
    'hard rock',
    'ambient',
    'acoustic',
    'experimental',
    'female vocalist',
    'guitar',
    'Hip-Hop',
    '70s',
    'party',
    'country',
    'easy listening',
    'sexy',
    'catchy',
    'funk',
    'electro',
    'heavy metal',
    'Progressive rock',
    '60s',
    'rnb',
    'indie pop',
    'sad',
    'House',
    'happy',
]
16+
17+
18+
def librosa_exists():
    '''Report whether the `librosa` package can be imported.

    # Returns
        `True` if importing `librosa` succeeds,
        `False` if the import raises `ImportError`.
    '''
    try:
        # Imported purely as an availability probe; the name is not used.
        import librosa  # noqa: F401
    except ImportError:
        return False
    return True
25+
26+
27+
def preprocess_input(audio_path, dim_ordering='default'):
    '''Reads an audio file and outputs a log-amplitude Mel-spectrogram.

    The signal is loaded at 12 kHz and forced to a fixed length of
    `int(29.12 * 12000)` samples: shorter signals are zero-padded at the
    end, longer signals are cropped around their center.

    # Arguments
        audio_path: path of the audio file, passed to `librosa.load`.
        dim_ordering: one of 'tf', 'th' or 'default'. With 'default',
            the value of `K.image_dim_ordering()` is used.

    # Returns
        The 2D spectrogram (96 Mel bands on the first axis) with a
        singleton channel axis added: in front for 'th', at the end
        for 'tf'.

    # Raises
        AssertionError: if the resolved `dim_ordering` is neither
            'tf' nor 'th'.
        RuntimeError: if librosa cannot be imported.
    '''
    if dim_ordering == 'default':
        dim_ordering = K.image_dim_ordering()
    assert dim_ordering in {'tf', 'th'}

    if librosa_exists():
        import librosa
    else:
        raise RuntimeError('Librosa is required to process audio files.\n' +
                           'Install it via `pip install librosa` \nor visit ' +
                           'http://librosa.github.io/librosa/ for details.')

    # mel-spectrogram parameters
    SR = 12000
    N_FFT = 512
    N_MELS = 96
    HOP_LEN = 256
    DURA = 29.12

    src, _ = librosa.load(audio_path, sr=SR)
    n_sample = src.shape[0]
    n_sample_wanted = int(DURA * SR)

    if n_sample < n_sample_wanted:  # too short: zero-pad at the end
        src = np.hstack((src, np.zeros((n_sample_wanted - n_sample,))))
    elif n_sample > n_sample_wanted:  # too long: keep the centered window
        # Integer division is required: float bounds are not valid
        # slice indices on Python 3.
        start = (n_sample - n_sample_wanted) // 2
        src = src[start:start + n_sample_wanted]

    # The squared Mel-spectrogram is kept exactly as in the original
    # pipeline so the features stay consistent with what the code
    # produced before.
    power = librosa.feature.melspectrogram(y=src, sr=SR, hop_length=HOP_LEN,
                                           n_fft=N_FFT, n_mels=N_MELS) ** 2
    if hasattr(librosa, 'logamplitude'):
        x = librosa.logamplitude(power, ref_power=1.0)
    else:
        # `logamplitude` was removed from newer librosa releases;
        # `power_to_db` is its replacement.
        x = librosa.power_to_db(power, ref=1.0)

    if dim_ordering == 'th':
        x = np.expand_dims(x, axis=0)
    elif dim_ordering == 'tf':
        # `x` is 2D here, so the channel axis goes last; `axis=3` is out
        # of range and rejected by current NumPy.
        x = np.expand_dims(x, axis=-1)
    return x
70+
71+
72+
def decode_predictions(preds, top_n=5):
    '''Turn the raw output of a music tagger model into ranked tags.

    # Arguments
        preds: 2-dimensional numpy array with 50 columns,
            one row per sample.
        top_n: integer in [0, 50], number of items to keep per sample.

    # Returns
        A list with one entry per row of `preds`. Each entry is a list of
        `(tag, score)` tuples sorted by decreasing score and truncated
        to `top_n` items.
    '''
    assert len(preds.shape) == 2 and preds.shape[1] == 50
    return [
        sorted(zip(TAGS, scores),
               key=lambda tag_score: tag_score[1],
               reverse=True)[:top_n]
        for scores in preds
    ]

music_tagger_crnn.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# -*- coding: utf-8 -*-
2+
'''MusicTaggerCRNN model for Keras.
3+
4+
Code by github.com/keunwoochoi.
5+
6+
# Reference:
7+
8+
- [Music-auto_tagging-keras](https://github.com/keunwoochoi/music-auto_tagging-keras)
9+
10+
'''
11+
from __future__ import print_function
12+
from __future__ import absolute_import
13+
14+
import numpy as np
15+
from keras import backend as K
16+
from keras.layers import Input, Dense
17+
from keras.models import Model
18+
from keras.layers import Dense, Dropout, Reshape, Permute
19+
from keras.layers.convolutional import Convolution2D
20+
from keras.layers.convolutional import MaxPooling2D, ZeroPadding2D
21+
from keras.layers.normalization import BatchNormalization
22+
from keras.layers.advanced_activations import ELU
23+
from keras.layers.recurrent import GRU
24+
from keras.utils.data_utils import get_file
25+
from keras.utils.layer_utils import convert_all_kernels_in_model
26+
from audio_conv_utils import decode_predictions, preprocess_input
27+
28+
TH_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.3/music_tagger_crnn_weights_tf_kernels_th_dim_ordering.h5'
29+
TF_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.3/music_tagger_crnn_weights_tf_kernels_tf_dim_ordering.h5'
30+
31+
32+
def MusicTaggerCRNN(weights='msd', input_tensor=None,
                    include_top=True):
    '''Instantiate the MusicTaggerCRNN architecture,
    optionally loading weights pre-trained
    on Million Song Dataset. Note that when using TensorFlow,
    for best performance you should set
    `image_dim_ordering="tf"` in your Keras config
    at ~/.keras/keras.json.

    The model and the weights are compatible with both
    TensorFlow and Theano. The dimension ordering
    convention used by the model is the one
    specified in your Keras config file.

    For preparing mel-spectrogram input, see
    `audio_conv_utils.py` in [applications](https://github.com/fchollet/keras/tree/master/keras/applications).
    You will need to install [Librosa](http://librosa.github.io/librosa/)
    to use it.

    # Arguments
        weights: one of `None` (random initialization)
            or "msd" (pre-training on Million Song Dataset).
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
            to use as mel-spectrogram input for the model.
        include_top: whether to include the 1 fully-connected
            layer (output layer) at the top of the network.
            If False, the network outputs 32-dim features.

    # Returns
        A Keras model instance.

    # Raises
        ValueError: if `weights` is neither `None` nor "msd".
    '''
    if weights not in {'msd', None}:
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization) or `msd` '
                         '(pre-training on Million Song Dataset).')

    # Query the configured dimension ordering once and derive the input
    # shape and axis indices from it together.
    dim_ordering = K.image_dim_ordering()
    if dim_ordering == 'th':
        input_shape = (1, 96, 1366)
        channel_axis = 1
        time_axis = 3
    else:
        input_shape = (96, 1366, 1)
        channel_axis = 3
        time_axis = 2

    if input_tensor is None:
        melgram_input = Input(shape=input_shape)
    elif not K.is_keras_tensor(input_tensor):
        melgram_input = Input(tensor=input_tensor, shape=input_shape)
    else:
        melgram_input = input_tensor

    # Input block: pad the time axis from 1366 to 1440 frames so that the
    # four pooling stages below (2 * 3 * 4 * 4 = 96) divide it evenly.
    x = ZeroPadding2D(padding=(0, 37))(melgram_input)
    # NOTE(review): the layer is named 'bn_0_freq' but normalizes over
    # `time_axis`. Left unchanged because switching the axis changes the
    # layer's parameter shape -- confirm against the pre-trained weights
    # before touching it.
    x = BatchNormalization(axis=time_axis, name='bn_0_freq')(x)

    # Conv blocks 1-4: (number of filters, pooling size). Layer names are
    # 'conv1'..'conv4', 'bn1'..'bn4', 'pool1'..'pool4' so that loading
    # weights by name keeps working.
    conv_blocks = [(64, (2, 2)), (128, (3, 3)), (128, (4, 4)), (128, (4, 4))]
    for index, (nb_filter, pool) in enumerate(conv_blocks, start=1):
        x = Convolution2D(nb_filter, 3, 3, border_mode='same',
                          name='conv%d' % index)(x)
        x = BatchNormalization(axis=channel_axis, mode=0,
                               name='bn%d' % index)(x)
        x = ELU()(x)
        x = MaxPooling2D(pool_size=pool, strides=pool,
                         name='pool%d' % index)(x)

    # Reshaping: pooling leaves 1 frequency bin and 15 time steps with
    # 128 channels; arrange them as a (time, features) sequence.
    if dim_ordering == 'th':
        # Move the time axis first so the reshape keeps time-major order.
        x = Permute((3, 1, 2))(x)
    x = Reshape((15, 128))(x)

    # GRU block 1, 2, output
    x = GRU(32, return_sequences=True, name='gru1')(x)
    x = GRU(32, return_sequences=False, name='gru2')(x)

    if include_top:
        x = Dense(50, activation='sigmoid', name='output')(x)

    # Create model
    model = Model(melgram_input, x)
    if weights is None:
        return model

    # Load weights
    if dim_ordering == 'tf':
        weights_path = get_file('music_tagger_crnn_weights_tf_kernels_tf_dim_ordering.h5',
                                TF_WEIGHTS_PATH,
                                cache_subdir='models')
    else:
        weights_path = get_file('music_tagger_crnn_weights_tf_kernels_th_dim_ordering.h5',
                                TH_WEIGHTS_PATH,
                                cache_subdir='models')
    # Loading by name lets the same weight file be used with and without
    # the 'output' layer.
    model.load_weights(weights_path, by_name=True)
    if K.backend() == 'theano':
        # The weight files store TensorFlow-style kernels (see file names).
        convert_all_kernels_in_model(model)
    return model
151+
152+
153+
if __name__ == '__main__':
    # Demo: tag a single local audio clip with the pre-trained model.
    tagger = MusicTaggerCRNN(weights='msd')

    clip_path = 'audio_file.mp3'
    # `predict` is fed a batch, so wrap the single spectrogram in a
    # leading batch axis.
    batch = np.expand_dims(preprocess_input(clip_path), axis=0)

    scores = tagger.predict(batch)
    print('Predicted:')
    print(decode_predictions(scores))

0 commit comments

Comments
 (0)