@@ -663,15 +663,25 @@ def test_equality_operator(tmp_path, setup_random_dwi_data):
663663 bzero = b0_dataobj ,
664664 )
665665
666+ # Get all user-defined, named attributes
667+ attrs_to_check = [
668+ a .name for a in attrs .fields (DWI ) if not a .name .startswith ("_" ) and not a .name .isdigit ()
669+ ]
670+
666671 # Sanity checks (element-wise)
667- assert np .allclose (dwi_obj_direct .dataobj , dwi_obj_from_nii .dataobj )
668- assert np .allclose (dwi_obj_direct .affine , dwi_obj_from_nii .affine )
669- if dwi_obj_direct .brainmask is None or dwi_obj_from_nii .brainmask is None :
670- assert dwi_obj_direct .brainmask is None
671- assert dwi_obj_from_nii .brainmask is None
672- else :
673- assert np .array_equal (dwi_obj_direct .brainmask , dwi_obj_from_nii .brainmask )
674- assert np .allclose (dwi_obj_direct .gradients , dwi_obj_from_nii .gradients )
672+ for attr_name in attrs_to_check :
673+ val_direct = getattr (dwi_obj_direct , attr_name )
674+ val_from_nii = getattr (dwi_obj_from_nii , attr_name )
675+
676+ if val_direct is None or val_from_nii is None :
677+ assert val_direct is None and val_from_nii is None , f"{ attr_name } mismatch"
678+ else :
679+ if isinstance (val_direct , np .ndarray ):
680+ assert val_direct .shape == val_from_nii .shape
681+ assert np .allclose (val_direct , val_from_nii ), f"{ attr_name } arrays differ"
682+ else :
683+ assert val_direct == val_from_nii , f"{ attr_name } values differ"
684+
675685 # Properties derived from gradients should also match
676686 assert np .allclose (dwi_obj_direct .bvals , dwi_obj_from_nii .bvals )
677687 assert np .allclose (dwi_obj_direct .bvecs , dwi_obj_from_nii .bvecs )
0 commit comments