@@ -49,25 +49,25 @@ def get_nnunet_trainer(
4949 The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network,
5050 optimizer, loss function, DataLoader, etc.
5151
52- ```python
53- from monai.apps import SupervisedTrainer
54- from monai.bundle.nnunet import get_nnunet_trainer
55-
56- dataset_name_or_id = 'Task101_PROSTATE'
57- fold = 0
58- configuration = '3d_fullres'
59- nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)
60-
61- trainer = SupervisedTrainer(
62- device=nnunet_trainer.device,
63- max_epochs =nnunet_trainer.num_epochs ,
64- train_data_loader =nnunet_trainer.dataloader_train ,
65- network =nnunet_trainer.network ,
66- optimizer =nnunet_trainer.optimizer ,
67- loss_function =nnunet_trainer.loss_function ,
68- epoch_length =nnunet_trainer.num_iterations_per_epoch ,
69-
70- ```
52+ Example::
53+
54+ from monai.apps import SupervisedTrainer
55+ from monai.bundle.nnunet import get_nnunet_trainer
56+
57+ dataset_name_or_id = 'Task101_PROSTATE'
58+ fold = 0
59+ configuration = '3d_fullres'
60+ nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)
61+
62+ trainer = SupervisedTrainer(
63+ device =nnunet_trainer.device ,
64+ max_epochs =nnunet_trainer.num_epochs ,
65+ train_data_loader =nnunet_trainer.dataloader_train ,
66+ network =nnunet_trainer.network ,
67+ optimizer =nnunet_trainer.optimizer ,
68+ loss_function =nnunet_trainer.loss_function ,
69+ epoch_length=nnunet_trainer.num_iterations_per_epoch,
70+ )
7171
7272 Parameters
7373 ----------
@@ -162,16 +162,19 @@ class ModelnnUNetWrapper(torch.nn.Module):
162162 The folder path where the model and related files are stored.
163163 model_name : str, optional
164164 The name of the model file, by default "model.pt".
165+
165166 Attributes
166167 ----------
167- predictor : object
168- The predictor object used for inference.
168+ predictor : nnUNetPredictor
169+ The nnUNet predictor object used for inference.
169170 network_weights : torch.nn.Module
170171 The network weights of the model.
172+
171173 Methods
172174 -------
173175 forward(x)
174176 Perform forward pass and prediction on the input data.
177+
175178 Notes
176179 -----
177180 This class integrates nnUNet model with MONAI framework by loading necessary configurations,
@@ -183,7 +186,7 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
183186 self .predictor = predictor
184187
185188 model_training_output_dir = model_folder
186- use_folds = "0"
189+ use_folds = [ "0" ]
187190
188191 from nnunetv2 .utilities .plans_handling .plans_handler import PlansManager
189192
@@ -290,27 +293,28 @@ def forward(self, x):
290293
291294def get_nnunet_monai_predictor (model_folder , model_name = "model.pt" ):
292295 """
293- Initializes and returns a nnUNetMONAIModelWrapper with a nnUNetPredictor.
296+ Initializes and returns a ` nnUNetMONAIModelWrapper` containing the corresponding ` nnUNetPredictor` .
294297 The model folder should contain the following files, created during training:
295- - dataset.json: from the nnUNet results folder.
296- - plans .json: from the nnUNet results folder.
297- - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration
298- (`init_kwargs`, `trainer_name`, `inference_allowed_mirroring_axes`).
299- - model.pt: The checkpoint file containing the model weights.
300-
298+
299+ - dataset .json: from the nnUNet results folder
300+ - plans.json: from the nnUNet results folder
301+ - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration (`init_kwargs`, `trainer_name`, `inference_allowed_mirroring_axes`)
302+ - model.pt: The checkpoint file containing the model weights.
303+
301304 The returned wrapper object can be used for inference with MONAI framework:
302- ```python
303- from monai.bundle.nnunet import get_nnunet_monai_predictor
305+
306+ Example::
307+
308+ from monai.bundle.nnunet import get_nnunet_monai_predictor
304309
305- model_folder = 'path/to/monai_bundle/model'
306- model_name = 'model.pt'
307- wrapper = get_nnunet_monai_predictor(model_folder, model_name)
310+ model_folder = 'path/to/monai_bundle/model'
311+ model_name = 'model.pt'
312+ wrapper = get_nnunet_monai_predictor(model_folder, model_name)
308313
309- # Perform inference
310- input_data = ...
311- output = wrapper(input_data)
314+ # Perform inference
315+ input_data = ...
316+ output = wrapper(input_data)
312317
313- ```
314318
315319 Parameters
316320 ----------
0 commit comments