@@ -110,23 +110,43 @@ def __init__(self, num_channels, head_dim, operation_settings=None):
        self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

        self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.num_chunks = 2

-    def forward(self, x, freqs, transformer_options={}):
-        def compute_q(x):
-            q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
-            return apply_rope1(self.query_norm(q), freqs)
-
-        def compute_k(x):
-            k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
-            return apply_rope1(self.key_norm(k), freqs)
+    def _compute_qk(self, x, freqs, proj_fn, norm_fn):
+        result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
+        return apply_rope1(norm_fn(result), freqs)

-        q = compute_q(x)
-        k = compute_k(x)
+    def _forward(self, x, freqs, transformer_options={}):
+        q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
+        k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
+        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
+        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        return self.out_layer(out)

+    def _forward_chunked(self, x, freqs, transformer_options={}):
+        def process_chunks(proj_fn, norm_fn):
+            B, L, _ = x.shape
+            chunk_size = (L + self.num_chunks - 1) // self.num_chunks
+            chunks = []
+            for i in range(0, L, chunk_size):
+                end_idx = min(i + chunk_size, L)
+                x_chunk = x[:, i:end_idx]
+                freqs_chunk = freqs[:, i:end_idx]
+                chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
+            return torch.cat(chunks, dim=1)
+
+        q = process_chunks(self.to_query, self.query_norm)
+        k = process_chunks(self.to_key, self.key_norm)
        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
        return self.out_layer(out)

+    def forward(self, x, freqs, transformer_options={}):
+        if x.shape[1] > 8192:
+            return self._forward_chunked(x, freqs, transformer_options=transformer_options)
+        else:
+            return self._forward(x, freqs, transformer_options=transformer_options)
+

class CrossAttention(SelfAttention):
    def get_qkv(self, x, context):
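On the attention side, the refactor folds the duplicated `compute_q`/`compute_k` closures into a shared `_compute_qk` helper, then adds a `_forward_chunked` path that `forward` selects once the sequence exceeds 8192 tokens. Only the Q/K projection, RMSNorm, and RoPE are chunked along the sequence dimension; the attention call itself still sees the full sequence. Because all three chunked ops are position-wise, slicing `freqs` together with `x` should reproduce the unchunked result up to float tolerance. A minimal sketch of that equivalence, using hypothetical stand-ins (`proj`, `norm`, and `apply_rope_stub` are placeholders, not the repo's `apply_rope1`):

```python
import torch

# Hypothetical stand-ins: any position-wise projection, norm, and rotary
# embedding interact with sequence chunking the same way.
torch.manual_seed(0)
proj = torch.nn.Linear(64, 64)
norm = torch.nn.RMSNorm(16)  # torch.nn.RMSNorm needs PyTorch >= 2.4

def apply_rope_stub(q, freqs):
    # Placeholder rotary step: strictly per-position, like real RoPE.
    return q * freqs.unsqueeze(-2)  # (B, L, 1, D) broadcasts over heads

x = torch.randn(1, 10, 64)       # (B, L, C)
freqs = torch.randn(1, 10, 16)   # one frequency vector per position

def compute_qk(x, freqs):
    # Mirrors _compute_qk: project, split heads, normalize, apply RoPE.
    q = proj(x).view(*x.shape[:-1], 4, 16)
    return apply_rope_stub(norm(q), freqs)

full = compute_qk(x, freqs)

# Chunked path: slice x and freqs together, concatenate along sequence.
chunks = [compute_qk(x[:, i:i + 5], freqs[:, i:i + 5]) for i in range(0, 10, 5)]
assert torch.allclose(full, torch.cat(chunks, dim=1), atol=1e-6)
```

With `num_chunks = 2`, the temporaries created inside the projection/norm/RoPE pipeline peak at half the sequence length, though the concatenated `q` and `k` themselves are still full-size.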
@@ -150,22 +170,24 @@ def __init__(self, dim, ff_dim, operation_settings=None):
        self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.num_chunks = 4

-    def forward(self, x):
-        #return self.out_layer(self.activation(self.in_layer(x)))
-        # ffn is the peak memory consumer, chunking here helps
-        B, L, C = x.shape
+    def _forward(self, x):
+        return self.out_layer(self.activation(self.in_layer(x)))
+
+    def _forward_chunked(self, x):
+        B, L, _ = x.shape
        chunk_size = (L + self.num_chunks - 1) // self.num_chunks
        output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
-
        for i in range(0, L, chunk_size):
            end_idx = min(i + chunk_size, L)
-            def compute_chunk(x_chunk):
-                activated = self.activation(self.in_layer(x_chunk))
-                return self.out_layer(activated)
-            output[:, i:end_idx] = compute_chunk(x[:, i:end_idx])
-
+            output[:, i:end_idx] = self._forward(x[:, i:end_idx])
        return output

+    def forward(self, x):
+        if x.shape[1] > 8192:
+            return self._forward_chunked(x)
+        else:
+            return self._forward(x)
+

class OutLayer(nn.Module):
    def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
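The FeedForward change follows the same template: the previously commented-out one-liner becomes `_forward`, `_forward_chunked` fills a preallocated output buffer slice by slice so only one chunk's `ff_dim`-wide activation is alive at a time, and `forward` dispatches on the same 8192-token threshold. A self-contained sketch of the pattern (the module and sizes are illustrative, not the repo's `operations.Linear` wrappers):

```python
import torch
import torch.nn as nn

class ChunkedFeedForward(nn.Module):
    # Illustrative module following the PR's structure; dim/ff_dim are
    # arbitrary, and 8192 is the same dispatch threshold the diff uses.
    def __init__(self, dim=64, ff_dim=256, num_chunks=4):
        super().__init__()
        self.in_layer = nn.Linear(dim, ff_dim, bias=False)
        self.activation = nn.GELU()
        self.out_layer = nn.Linear(ff_dim, dim, bias=False)
        self.num_chunks = num_chunks

    def _forward(self, x):
        return self.out_layer(self.activation(self.in_layer(x)))

    def _forward_chunked(self, x):
        B, L, _ = x.shape
        chunk_size = (L + self.num_chunks - 1) // self.num_chunks
        # Preallocate so no full (B, L, ff_dim) intermediate ever exists;
        # each iteration only holds one (B, chunk, ff_dim) activation.
        output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
        for i in range(0, L, chunk_size):
            end_idx = min(i + chunk_size, L)
            output[:, i:end_idx] = self._forward(x[:, i:end_idx])
        return output

    def forward(self, x):
        return self._forward_chunked(x) if x.shape[1] > 8192 else self._forward(x)

# The MLP is position-wise, so chunked and direct paths agree.
ff = ChunkedFeedForward()
x = torch.randn(2, 100, 64)
assert torch.allclose(ff._forward_chunked(x), ff._forward(x), atol=1e-5)
```

Routing the chunked loop back through `_forward` keeps one source of truth for the MLP math, which is what lets the diff delete the nested `compute_chunk` closure.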