Skip to content

Commit 84e9ce3

Browse files
Implement the mmaudio VAE. (comfyanonymous#10300)
1 parent f43b8ab commit 84e9ce3

File tree

9 files changed

+1247
-0
lines changed

9 files changed

+1247
-0
lines changed

comfy/ldm/mmaudio/vae/__init__.py

Whitespace-only changes.
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2+
# LICENSE is in incl_licenses directory.
3+
4+
import torch
5+
from torch import nn, sin, pow
6+
from torch.nn import Parameter
7+
import comfy.model_management
8+
9+
class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function:
    Snake := x + 1/a * sin^2 (xa)
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - per-channel parameter of shape (in_features,)
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> # alpha is uninitialized at this point; load or assign its values first
        >>> x = torch.randn(1, 256, 16)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: number of channels C; alpha stores one value per channel
            - alpha: accepted for interface compatibility, not read by this constructor
            - alpha_trainable: value assigned to alpha.requires_grad
            - alpha_logscale: when True, forward() uses exp(alpha) in place of alpha

            alpha is allocated with torch.empty, so its contents are undefined
            until they are overwritten (presumably by loading a state dict).
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        # The allocation does not depend on alpha_logscale; the flag only
        # changes how forward() interprets the stored values.
        self.alpha_logscale = alpha_logscale
        self.alpha = Parameter(torch.empty(in_features))
        self.alpha.requires_grad = alpha_trainable

        # Added to the denominator in forward() so that alpha == 0 does not divide by zero.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2 (xa)
        '''
        # Match x's dtype/device, then reshape (C,) -> (1, C, 1) to line up with x of shape [B, C, T].
        alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
60+
61+
62+
class SnakeBeta(nn.Module):
    '''
    A modified Snake function which uses separate parameters for the magnitude of the periodic components:
    SnakeBeta := x + 1/b * sin^2 (xa)
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - per-channel parameter that controls frequency
        - beta - per-channel parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256)
        >>> # alpha and beta are uninitialized at this point; load or assign their values first
        >>> x = torch.randn(1, 256, 16)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: number of channels C; alpha and beta store one value per channel
            - alpha: accepted for interface compatibility, not read by this constructor
            - alpha_trainable: value assigned to both alpha.requires_grad and beta.requires_grad
            - alpha_logscale: when True, forward() uses exp(alpha) and exp(beta)

            alpha and beta are allocated with torch.empty, so their contents are
            undefined until they are overwritten (presumably by loading a state dict).
        '''
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # The allocations do not depend on alpha_logscale; the flag only
        # changes how forward() interprets the stored values.
        self.alpha_logscale = alpha_logscale
        self.alpha = Parameter(torch.empty(in_features))
        self.beta = Parameter(torch.empty(in_features))

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        # Added to the denominator in forward() so that beta == 0 does not divide by zero.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        '''
        # Match x's dtype/device, then reshape (C,) -> (1, C, 1) to line up with x of shape [B, C, T].
        alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
        beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
import math
5+
import comfy.model_management
6+
7+
# Prefer torch's own sinc when this torch build exposes one; otherwise define
# a fallback with the same normalized definition.
if hasattr(torch, 'sinc'):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    #   LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!

        Entries where x == 0 are replaced by 1, the limit of the expression.
        """
        return torch.where(x == 0,
                           torch.tensor(1., device=x.device, dtype=x.dtype),
                           torch.sin(math.pi * x) / math.pi / x)
21+
22+
23+
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
24+
# https://adefossez.github.io/julius/julius/lowpass.html
25+
# LICENSE is in incl_licenses directory.
26+
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
    """
    Build a Kaiser-windowed sinc low-pass kernel.

    Args:
        cutoff: normalized cutoff frequency (LowPassFilter1d in this file
            rejects values above 0.5). A cutoff of 0 yields an all-zero kernel.
        half_width: transition half-width in the same unit as cutoff; it only
            influences the Kaiser window's beta.
        kernel_size: number of taps in the kernel.

    Returns:
        A floating-point tensor of shape [1, 1, kernel_size]. For a non-zero
        cutoff the taps are normalized to sum to 1.
    """
    even = (kernel_size % 2 == 0)
    half_size = kernel_size // 2

    # For kaiser window: pick beta from the attenuation estimate A using
    # Kaiser's empirical piecewise formula.
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.:
        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
    else:
        beta = 0.
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    if cutoff == 0:
        # Nothing passes: return zeros directly. This skips the 0/0
        # normalization and always yields a float tensor of the documented
        # shape, for even and odd kernel sizes alike.
        return torch.zeros_like(window).view(1, 1, kernel_size)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        # Even kernels have no centre tap, so sample at half-integer offsets.
        time = (torch.arange(-half_size, half_size) + 0.5)
    else:
        time = torch.arange(kernel_size) - half_size

    filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
    # Normalize filter to have sum = 1, otherwise we will have a small leakage
    # of the constant component in the input signal.
    filter_ /= filter_.sum()

    return filter_.view(1, 1, kernel_size)
56+
57+
58+
class LowPassFilter1d(nn.Module):
    """Channel-wise 1-D low-pass filter built on a fixed Kaiser-windowed sinc kernel."""

    def __init__(self,
                 cutoff=0.5,
                 half_width=0.6,
                 stride: int = 1,
                 padding: bool = True,
                 padding_mode: str = 'replicate',
                 kernel_size: int = 12):
        """
        Args:
            cutoff: passed to kaiser_sinc_filter1d; must lie in [0, 0.5].
            half_width: passed to kaiser_sinc_filter1d.
            stride: stride of the convolution in forward().
            padding: whether forward() pads the input before filtering.
            padding_mode: mode handed to F.pad when padding is enabled.
            kernel_size: number of filter taps.

        Raises:
            ValueError: if cutoff is negative or greater than 0.5.
        """
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < 0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")

        is_even = kernel_size % 2 == 0
        half = kernel_size // 2

        self.kernel_size = kernel_size
        self.even = is_even
        # An even kernel pads one sample less on the left than on the right.
        self.pad_left = half - int(is_even)
        self.pad_right = half
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))

    #input [B, C, T]
    def forward(self, x):
        """Filter each channel of x ([B, C, T]) with the stored kernel."""
        _batch, channels, _length = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)

        # One copy of the kernel per channel; groups=channels keeps channels independent.
        kernel = comfy.model_management.cast_to(self.filter.expand(channels, -1, -1), dtype=x.dtype, device=x.device)
        return F.conv1d(x, kernel, stride=self.stride, groups=channels)
94+
95+
96+
class UpSample1d(nn.Module):
    """Upsample a [B, C, T] signal by `ratio` with a low-pass transposed convolution."""

    def __init__(self, ratio=2, kernel_size=None):
        """
        Args:
            ratio: upsampling factor; also used as the transposed-conv stride.
            kernel_size: number of filter taps; defaults to int(6 * ratio // 2) * 2.
        """
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.stride = ratio
        # Replicate padding applied to the input on each side before the transposed conv.
        self.pad = self.kernel_size // ratio - 1
        # Samples trimmed from each end of the transposed-conv output; together
        # they leave T * ratio samples.
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        kernel = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
                                      half_width=0.6 / ratio,
                                      kernel_size=self.kernel_size)
        self.register_buffer("filter", kernel)

    # x: [B, C, T]
    def forward(self, x):
        """Return x upsampled along the last dimension."""
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode='replicate')
        # Scale by ratio after the transposed conv; each channel uses its own copy of the kernel.
        x = self.ratio * F.conv_transpose1d(
            x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
        # Use an explicit end index: `-self.pad_right` would select nothing
        # when pad_right is 0 (kernel_size == ratio), because -0 == 0.
        x = x[..., self.pad_left:x.shape[-1] - self.pad_right]

        return x
120+
121+
122+
class DownSample1d(nn.Module):
    """Downsample a signal by `ratio` by running it through a strided LowPassFilter1d."""

    def __init__(self, ratio=2, kernel_size=None):
        """
        Args:
            ratio: downsampling factor; used as the low-pass filter's stride.
            kernel_size: number of filter taps; defaults to int(6 * ratio // 2) * 2.
        """
        super().__init__()
        if kernel_size is None:
            kernel_size = int(6 * ratio // 2) * 2
        self.ratio = ratio
        self.kernel_size = kernel_size
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=kernel_size,
        )

    def forward(self, x):
        """Return the strided low-pass output for x."""
        return self.lowpass(x)
136+
137+
class Activation1d(nn.Module):
    """Apply an activation between an UpSample1d and a DownSample1d stage."""

    def __init__(self,
                 activation,
                 up_ratio: int = 2,
                 down_ratio: int = 2,
                 up_kernel_size: int = 12,
                 down_kernel_size: int = 12):
        """
        Args:
            activation: callable applied to the upsampled signal.
            up_ratio: ratio passed to UpSample1d.
            down_ratio: ratio passed to DownSample1d.
            up_kernel_size: kernel size passed to UpSample1d.
            down_kernel_size: kernel size passed to DownSample1d.
        """
        super().__init__()
        self.up_ratio, self.down_ratio = up_ratio, down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B,C,T]
    def forward(self, x):
        """Upsample x, apply the activation, then downsample the result."""
        upsampled = self.upsample(x)
        activated = self.act(upsampled)
        return self.downsample(activated)

0 commit comments

Comments
 (0)