1+ from collections .abc import Iterable
12from functools import partial
23from itertools import islice , cycle
34
@@ -21,9 +22,7 @@ def default(val, d):
2122 return val if exists (val ) else d
2223
2324def cast_tuple (val , depth = 1 ):
24- if isinstance (val , list ):
25- val = tuple (val )
26- return val if isinstance (val , tuple ) else (val ,) * depth
25+ return val if isinstance (val , Iterable ) else (val ,) * depth
2726
2827# classes
2928
@@ -184,15 +183,16 @@ def __init__(
184183 else :
185184 raise ValueError (f'attention type "{ attn_type } " is not valid' )
186185
187- attn = shared_attn_layers .get (attn_id )
186+ attn , reused_attn_type = shared_attn_layers .get (attn_id , ( None , None ) )
188187 if not exists (attn ):
189188 if attn_type != 'mlp' :
190189 attn = attn_class (dim , causal = causal , seq_len = seq_len , heads = heads , dim_head = dim_head , dropout = attn_dropout )
191190 else :
192191 attn = attn_class (dim = dim , causal = causal , dim_ff = dim * 4 )
193- shared_attn_layers [attn_id ] = attn
194- else :
195- assert isinstance (attn , attn_class ), 'attn_types do not match shared_attn_ids'
192+ shared_attn_layers [attn_id ] = (attn , attn_type )
193+ elif attn_type != reused_attn_type :
194+ raise ValueError ('attn_types do not match shared_attn_ids '
195+ f'(ind = { ind } , attn_type = "{ attn_type } ", reused_attn_type = "{ reused_attn_type } ")' )
196196
197197 ff = shared_ff_layers .get (ff_id )
198198 if not exists (ff ):
0 commit comments