Skip to content

Commit 1463f92

Browse files
Also test deep learning for longitudinal testing
1 parent dfa7c4f commit 1463f92

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

conftest.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
import numpy as np
77
from phantoms.MR_XCAT_qMRI.sim_ivim_sig import phantom
88
import warnings
9+
import os
10+
import torch
11+
import random
912
from tests.IVIMmodels.unit_tests.test_ivim_fit import PerformanceWarning
1013
warnings.simplefilter("always", PerformanceWarning)
1114

@@ -98,6 +101,21 @@ def pytest_addoption(parser):
98101
help="Run MATLAB-dependent tests"
99102
)
100103

104+
def set_global_seed(seed: int = 42):
105+
os.environ["PYTHONHASHSEED"] = str(seed)
106+
random.seed(seed)
107+
np.random.seed(seed)
108+
torch.manual_seed(seed)
109+
torch.cuda.manual_seed(seed)
110+
torch.cuda.manual_seed_all(seed)
111+
torch.backends.cudnn.deterministic = True
112+
torch.backends.cudnn.benchmark = False
113+
print(f"✅ Global seed set to {seed}")
114+
115+
@pytest.fixture(autouse=True, scope="session")
116+
def global_seed():
117+
"""Automatically seed all random generators at test session start."""
118+
set_global_seed(1234)
101119

102120
@pytest.fixture(scope="session")
103121
def eng(request):

tests/IVIMmodels/unit_tests/test_ivim_synthetic.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from src.wrappers.OsipiBase import OsipiBase
88
from utilities.data_simulation.GenerateData import GenerateData
99

10+
TRAINED_MODELS = {}
11+
1012
#run using pytest <path_to_this_file> --saveFileName test_output.txt --SNR 50 100 200
1113
#e.g. pytest -m slow tests/IVIMmodels/unit_tests/test_ivim_synthetic.py --saveFileName test_output.csv --SNR 10 50 100 200 --fitCount 20
1214
@pytest.mark.slow
@@ -15,9 +17,6 @@ def test_generated(algorithmlist, ivim_data, SNR, rtol, atol, fit_count, rician_
1517
ivim_algorithm, requires_matlab, deep_learning = algorithmlist
1618
if requires_matlab and eng is None:
1719
pytest.skip(reason="Running without matlab; if Matlab is available please run pytest --withmatlab")
18-
if deep_learning:
19-
pytest.skip(
20-
reason="Slow drifting in performance not yet implmented for deep learning algorithms") # requieres training a network per b-value set and inferencing all data in 1 go. So not 1 data point per time, but all data in 1 go :). Otherwise network will be trained many many times...
2120
rng = np.random.RandomState(42)
2221
# random.seed(42)
2322
S0 = 1
@@ -26,7 +25,16 @@ def test_generated(algorithmlist, ivim_data, SNR, rtol, atol, fit_count, rician_
2625
D = data["D"]
2726
f = data["f"]
2827
Dp = data["Dp"]
29-
fit = OsipiBase(algorithm=ivim_algorithm)
28+
if deep_learning:
29+
if ivim_algorithm+str(SNR) not in TRAINED_MODELS:
30+
print(f"Training deep learning model {ivim_algorithm} ...")
31+
fit = OsipiBase(bvalues=bvals, algorithm=ivim_algorithm,SNR=SNR)
32+
TRAINED_MODELS[ivim_algorithm+str(SNR)] = fit
33+
else:
34+
print(f"Reusing trained model {ivim_algorithm}")
35+
fit = TRAINED_MODELS[ivim_algorithm+str(SNR)]
36+
else:
37+
fit = OsipiBase(algorithm=ivim_algorithm)
3038
# here is a prior
3139
if use_prior and hasattr(fit, "supported_priors") and fit.supported_priors:
3240
prior = [rng.normal(D, D/3, 10), rng.normal(f, f/3, 10), rng.normal(Dp, Dp/3, 10), rng.normal(1, 1/3, 10)]

0 commit comments

Comments
 (0)