Skip to content

Commit e9e63d6

Browse files
Add non-static and kt sampling (#280)
* kt sampling mask functions: `KtGaussian1DMaskFunc`, `KtRadialMaskFunc`, `KtUniformMaskFunc`, * Non-static sampling (dynamic/multislice) dicitated by the `MaskFuncMode`, which can be STATIC, MULTISLICE, DYNAMIC * Corresponding tests
1 parent d733b81 commit e9e63d6

File tree

8 files changed

+1711
-469
lines changed

8 files changed

+1711
-469
lines changed

direct/common/subsample.py

Lines changed: 1464 additions & 357 deletions
Large diffs are not rendered by default.

direct/common/subsample_config.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
1-
# coding=utf-8
21
# Copyright (c) DIRECT Contributors
2+
3+
from __future__ import annotations
4+
35
from dataclasses import dataclass
4-
from typing import Optional, Tuple
6+
from typing import Optional
57

68
from omegaconf import MISSING
79

810
from direct.config.defaults import BaseConfig
11+
from direct.types import MaskFuncMode
912

1013

1114
@dataclass
1215
class MaskingConfig(BaseConfig):
1316
name: str = MISSING
14-
accelerations: Tuple[int, ...] = (5,) # Ideally Union[float, int].
15-
center_fractions: Optional[Tuple[float, ...]] = (0.1,) # Ideally Optional[Tuple[float, ...]]
17+
accelerations: tuple[float, ...] = (5.0,)
18+
center_fractions: Optional[tuple[float, ...]] = (0.1,)
1619
uniform_range: bool = False
17-
image_center_crop: bool = False
20+
mode: MaskFuncMode = MaskFuncMode.STATIC
1821

19-
val_accelerations: Tuple[int, ...] = (5, 10)
20-
val_center_fractions: Optional[Tuple[float, ...]] = (0.1, 0.05)
22+
val_accelerations: tuple[float, ...] = (5.0, 10.0)
23+
val_center_fractions: Optional[tuple[float, ...]] = (0.1, 0.05)

direct/data/mri_transforms.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
317317
Sample with `sampling_mask` key.
318318
"""
319319
if not self.shape:
320-
shape = sample["kspace"].shape[-3:]
320+
shape = sample["kspace"].shape[1:]
321321
elif any(_ is None for _ in self.shape): # Allow None as values.
322322
kspace_shape = list(sample["kspace"].shape[1:-1])
323323
shape = tuple(_ if _ else kspace_shape[idx] for idx, _ in enumerate(self.shape)) + (2,)
@@ -328,9 +328,6 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
328328

329329
sampling_mask = self.mask_func(shape=shape, seed=seed, return_acs=False)
330330

331-
if sample["kspace"].ndim == 5:
332-
sampling_mask = sampling_mask.unsqueeze(0)
333-
334331
if "padding" in sample:
335332
sampling_mask = T.apply_padding(sampling_mask, sample["padding"])
336333

direct/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from enum import Enum
99
from typing import NewType, Union
1010

11+
import numpy as np
1112
import torch
1213
from omegaconf.omegaconf import DictConfig
1314
from torch import nn as nn
@@ -19,6 +20,7 @@
1920
FileOrUrl = NewType("FileOrUrl", PathOrString)
2021
HasStateDict = Union[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, GradScaler]
2122
TensorOrNone = Union[None, torch.Tensor]
23+
TensorOrNdarray = Union[torch.Tensor, np.ndarray]
2224

2325

2426
class DirectEnum(str, Enum):
@@ -59,6 +61,12 @@ class TransformKey(DirectEnum):
5961
SCALING_FACTOR = "scaling_factor"
6062

6163

64+
class MaskFuncMode(DirectEnum):
65+
STATIC = "static"
66+
DYNAMIC = "dynamic"
67+
MULTISLICE = "multislice"
68+
69+
6270
class IntegerListOrTupleStringMeta(type):
6371
"""Metaclass for the :class:`IntegerListOrTupleString` class.
6472

tests/test_train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@
2525
)
2626
from direct.launch import launch
2727
from direct.train import setup_train
28+
from direct.types import MaskFuncMode
2829

2930

3031
def create_test_transform_cfg(transforms_type):
3132
transforms_config = TransformsConfig(
3233
normalization=NormalizationTransformConfig(scaling_key="masked_kspace"),
33-
masking=MaskingConfig(name="FastMRIRandom"),
34+
masking=MaskingConfig(name="FastMRIRandom", mode=MaskFuncMode.STATIC),
3435
cropping=CropTransformConfig(crop="(32, 32)"),
3536
sensitivity_map_estimation=SensitivityMapEstimationTransformConfig(estimate_sensitivity_maps=True),
3637
transforms_type=transforms_type,

0 commit comments

Comments
 (0)