Skip to content

Commit 4355037

Browse files
Adds 3D vSHARP model (#273)
* Adding 3D UNet & config * Add 3d vsharp & config * Minor fixes in typing and software package versions
1 parent b660012 commit 4355037

File tree

21 files changed

+1122
-49
lines changed

21 files changed

+1122
-49
lines changed

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ BROWSER := python -c "$$BROWSER_PYSCRIPT"
2626
help:
2727
@python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)
2828

29-
clean: clean-build clean-pyc clean-cpy clean-test clean-docs ## remove all build, test, coverage, docs and Python and cython artifacts
29+
clean: clean-build clean-pyc clean-cpy clean-ipynb clean-test clean-docs ## remove all build, test, coverage, docs and Python and cython artifacts
3030

3131
clean-build: ## remove build artifacts
3232
rm -fr build/
@@ -46,6 +46,9 @@ clean-cpy: ## remove cython file artifacts
4646
find . -name '*.cpp' -exec rm -f {} +
4747
find . -name '*.so' -exec rm -f {} +
4848

49+
clean-ipynb: ## remove ipynb artifacts
50+
find . -name '.ipynb_checkpoints' -exec rm -rf {} +
51+
4952
clean-test: ## remove test and coverage artifacts
5053
rm -fr .tox/
5154
rm -f .coverage

direct/nn/unet/config.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
# coding=utf-8
21
# Copyright (c) DIRECT Contributors
2+
33
from dataclasses import dataclass
44

55
from direct.config.defaults import ModelConfig
6+
from direct.nn.types import InitType
67

78

89
@dataclass
@@ -30,4 +31,13 @@ class Unet2dConfig(ModelConfig):
3031
dropout_probability: float = 0.0
3132
skip_connection: bool = False
3233
normalized: bool = False
33-
image_initialization: str = "zero_filled"
34+
image_initialization: InitType = InitType.ZERO_FILLED
35+
36+
37+
@dataclass
38+
class UnetModel3dConfig(ModelConfig):
39+
in_channels: int = 2
40+
out_channels: int = 2
41+
num_filters: int = 16
42+
num_pool_layers: int = 4
43+
dropout_probability: float = 0.0

direct/nn/unet/unet_2d.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from torch.nn import functional as F
1111

1212
from direct.data import transforms as T
13+
from direct.nn.types import InitType
1314

1415

1516
class ConvBlock(nn.Module):
@@ -334,7 +335,7 @@ def __init__(
334335
dropout_probability: float,
335336
skip_connection: bool = False,
336337
normalized: bool = False,
337-
image_initialization: str = "zero_filled",
338+
image_initialization: InitType = InitType.ZERO_FILLED,
338339
**kwargs,
339340
):
340341
"""Inits :class:`Unet2d`.
@@ -355,8 +356,8 @@ def __init__(
355356
If True, skip connection is used for the output. Default: False.
356357
normalized: bool
357358
If True, Normalized Unet is used. Default: False.
358-
image_initialization: str
359-
Type of image initialization. Default: "zero-filled".
359+
image_initialization: InitType
360+
Type of image initialization. Default: InitType.ZERO_FILLED.
360361
kwargs: dict
361362
"""
362363
super().__init__()
@@ -437,18 +438,18 @@ def forward(
437438
output: torch.Tensor
438439
Output image of shape (N, height, width, complex=2).
439440
"""
440-
if self.image_initialization == "sense":
441+
if self.image_initialization == InitType.SENSE:
441442
if sensitivity_map is None:
442-
raise ValueError("Expected sensitivity_map not to be None with 'sense' image_initialization.")
443+
raise ValueError("Expected sensitivity_map not to be None with InitType.SENSE image_initialization.")
443444
input_image = self.compute_sense_init(
444445
kspace=masked_kspace,
445446
sensitivity_map=sensitivity_map,
446447
)
447-
elif self.image_initialization == "zero_filled":
448+
elif self.image_initialization == InitType.ZERO_FILLED:
448449
input_image = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim)
449450
else:
450451
raise ValueError(
451-
f"Unknown image_initialization. Expected `sense` or `zero_filled`. "
452+
f"Unknown image_initialization. Expected InitType.ZERO_FILLED or InitType.SENSE. "
452453
f"Got {self.image_initialization}."
453454
)
454455

0 commit comments

Comments
 (0)