Skip to content

Commit dfaa6a0

Browse files
committed
Skeletion factory class
1 parent e7feedf commit dfaa6a0

File tree

5 files changed

+296
-2
lines changed

5 files changed

+296
-2
lines changed

monai/networks/layers/factories.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@ def use_factory(fact_args):
6868
import torch.nn as nn
6969

7070
from monai.networks.utils import has_nvfuser_instance_norm
71-
from monai.utils import look_up_option, optional_import
71+
from monai.utils import Factory, look_up_option, optional_import
7272

7373
__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]
7474

7575

76-
class LayerFactory:
76+
class LayerFactory(Factory):
7777
"""
7878
Factory object for creating layers, this uses given factory functions to actually produce the types or constructing
7979
callables. These functions are referred to by name and can be added at any time.

monai/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
# have to explicitly bring these in here to resolve circular import issues
1515
from .aliases import alias, resolve_name
16+
from .component_store import ComponentStore
1617
from .decorators import MethodReplacer, RestartGenerator
1718
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
1819
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather
@@ -61,6 +62,7 @@
6162
Weight,
6263
WSIPatchKeys,
6364
)
65+
from .factory import Factory
6466
from .jupyter_utils import StatusMembers, ThreadContainer
6567
from .misc import (
6668
MAX_SEED,

monai/utils/component_store.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
from collections import namedtuple
15+
from keyword import iskeyword
16+
from textwrap import dedent, indent
17+
from typing import Any, Callable, Iterable, TypeVar
18+
19+
T = TypeVar("T")
20+
from monai.utils.factory import Factory
21+
22+
23+
def is_variable(name):
24+
"""Returns True if `name` is a valid Python variable name and also not a keyword."""
25+
return name.isidentifier() and not iskeyword(name)
26+
27+
28+
class ComponentStore(Factory):
29+
"""
30+
Represents a storage object for other objects (specifically functions) keyed to a name with a description.
31+
32+
These objects act as global named places for storing components for objects parameterised by component names.
33+
Typically this is functions although other objects can be added. Printing a component store will produce a
34+
list of members along with their docstring information if present.
35+
36+
Example:
37+
38+
.. code-block:: python
39+
40+
TestStore = ComponentStore("Test Store", "A test store for demo purposes")
41+
42+
@TestStore.add_def("my_func_name", "Some description of your function")
43+
def _my_func(a, b):
44+
'''A description of your function here.'''
45+
return a * b
46+
47+
print(TestStore) # will print out name, description, and 'my_func_name' with the docstring
48+
49+
func = TestStore["my_func_name"]
50+
result = func(7, 6)
51+
52+
"""
53+
54+
_Component = namedtuple("Component", ("description", "value")) # internal value pair
55+
56+
def __init__(self, name: str, description: str) -> None:
57+
self.components: dict[str, self._Component] = {}
58+
self.name: str = name
59+
self.description: str = description
60+
61+
self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip()
62+
63+
def add(self, name: str, desc: str, value: T) -> T:
64+
"""Store the object `value` under the name `name` with description `desc`."""
65+
if not is_variable(name):
66+
raise ValueError("Name of component must be valid Python identifier")
67+
68+
self.components[name] = self._Component(desc, value)
69+
return value
70+
71+
def add_def(self, name: str, desc: str) -> Callable:
72+
"""Returns a decorator which stores the decorated function under `name` with description `desc`."""
73+
74+
def deco(func):
75+
"""Decorator to add a function to a store."""
76+
return self.add(name, desc, func)
77+
78+
return deco
79+
80+
def __contains__(self, name: str) -> bool:
81+
"""Returns True if the given name is stored."""
82+
return name in self.components
83+
84+
def __len__(self) -> int:
85+
"""Returns the number of stored components."""
86+
return len(self.components)
87+
88+
def __iter__(self) -> Iterable:
89+
"""Yields name/component pairs."""
90+
for k, v in self.components.items():
91+
yield k, v.value
92+
93+
def __str__(self):
94+
result = f"Component Store '{self.name}': {self.description}\nAvailable components:"
95+
for k, v in self.components.items():
96+
result += f"\n* {k}:"
97+
98+
if hasattr(v.value, "__doc__"):
99+
doc = indent(dedent(v.value.__doc__.lstrip("\n").rstrip()), " ")
100+
result += f"\n{doc}\n"
101+
else:
102+
result += f" {v.description}"
103+
104+
return result
105+
106+
def __getattr__(self, name: str) -> Any:
107+
"""Returns the stored object under the given name."""
108+
if name in self.components:
109+
return self.components[name].value
110+
else:
111+
return self.__getattribute__(name)
112+
113+
def __getitem__(self, name: str) -> Any:
114+
"""Returns the stored object under the given name."""
115+
if name in self.components:
116+
return self.components[name].value
117+
else:
118+
raise ValueError(f"Component '{name}' not found")

monai/utils/factory.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
"""
12+
Defines a generic factory class.
13+
"""
14+
15+
from __future__ import annotations
16+
17+
18+
class Factory:
19+
"""
20+
Baseline factory object.
21+
"""
22+
23+
# def __init__(self) -> None:
24+
# self.factories: dict[str, Callable] = {}
25+
#
26+
# @property
27+
# def names(self) -> tuple[str, ...]:
28+
# """
29+
# Produces all factory names.
30+
# """
31+
#
32+
# return tuple(self.factories)
33+
#
34+
# def add_factory_callable(self, name: str, func: Callable) -> None:
35+
# """
36+
# Add the factory function to this object under the given name.
37+
# """
38+
#
39+
# self.factories[name.upper()] = func
40+
# self.__doc__ = (
41+
# "The supported member"
42+
# + ("s are: " if len(self.names) > 1 else " is: ")
43+
# + ", ".join(f"``{name}``" for name in self.names)
44+
# + ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing."
45+
# )
46+
#
47+
# def factory_function(self, name: str) -> Callable:
48+
# """
49+
# Decorator for adding a factory function with the given name.
50+
# """
51+
#
52+
# def _add(func: Callable) -> Callable:
53+
# self.add_factory_callable(name, func)
54+
# return func
55+
#
56+
# return _add
57+
#
58+
# def get_constructor(self, factory_name: str, *args) -> Any:
59+
# """
60+
# Get the constructor for the given factory name and arguments.
61+
#
62+
# Raises:
63+
# TypeError: When ``factory_name`` is not a ``str``.
64+
#
65+
# """
66+
#
67+
# if not isinstance(factory_name, str):
68+
# raise TypeError(f"factory_name must a str but is {type(factory_name).__name__}.")
69+
#
70+
# func = look_up_option(factory_name.upper(), self.factories)
71+
# return func(*args)
72+
#
73+
# def __getitem__(self, args) -> Any:
74+
# """
75+
# Get the given name or name/arguments pair. If `args` is a callable it is assumed to be the constructor
76+
# itself and is returned, otherwise it should be the factory name or a pair containing the name and arguments.
77+
# """
78+
#
79+
# # `args[0]` is actually a type or constructor
80+
# if callable(args):
81+
# return args
82+
#
83+
# # `args` is a factory name or a name with arguments
84+
# if isinstance(args, str):
85+
# name_obj, args = args, ()
86+
# else:
87+
# name_obj, *args = args
88+
#
89+
# return self.get_constructor(name_obj, *args)
90+
#
91+
# def __getattr__(self, key):
92+
# """
93+
# If `key` is a factory name, return it, otherwise behave as inherited. This allows referring to factory names
94+
# as if they were constants, eg. `Fact.FOO` for a factory Fact with factory function foo.
95+
# """
96+
#
97+
# if key in self.factories:
98+
# return key
99+
#
100+
# return super().__getattribute__(key)
101+
#
102+
#

tests/test_component_store.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
from monai.utils import ComponentStore
17+
18+
19+
class TestComponentStore(unittest.TestCase):
20+
def setUp(self):
21+
self.cs = ComponentStore("TestStore", "I am a test store, please ignore")
22+
23+
def test_empty(self):
24+
self.assertEqual(len(self.cs), 0)
25+
self.assertEqual(list(self.cs), [])
26+
27+
def test_add(self):
28+
test_obj = object()
29+
30+
self.assertFalse("test_obj" in self.cs)
31+
32+
self.cs.add("test_obj", "Test object", test_obj)
33+
34+
self.assertTrue("test_obj" in self.cs)
35+
36+
self.assertEqual(len(self.cs), 1)
37+
self.assertEqual(list(self.cs), [("test_obj", test_obj)])
38+
39+
self.assertEqual(self.cs.test_obj, test_obj)
40+
self.assertEqual(self.cs["test_obj"], test_obj)
41+
42+
def test_add2(self):
43+
test_obj1 = object()
44+
test_obj2 = object()
45+
46+
self.cs.add("test_obj1", "Test object", test_obj1)
47+
self.cs.add("test_obj2", "Test object", test_obj2)
48+
49+
self.assertEqual(len(self.cs), 2)
50+
self.assertTrue("test_obj1" in self.cs)
51+
self.assertTrue("test_obj2" in self.cs)
52+
53+
def test_add_def(self):
54+
self.assertFalse("test_func" in self.cs)
55+
56+
@self.cs.add_def("test_func", "Test function")
57+
def test_func():
58+
return 123
59+
60+
self.assertTrue("test_func" in self.cs)
61+
62+
self.assertEqual(len(self.cs), 1)
63+
self.assertEqual(list(self.cs), [("test_func", test_func)])
64+
65+
self.assertEqual(self.cs.test_func, test_func)
66+
self.assertEqual(self.cs["test_func"], test_func)
67+
68+
# try adding the same function again
69+
self.cs.add_def("test_func", "Test function but with new description")(test_func)
70+
71+
self.assertEqual(len(self.cs), 1)
72+
self.assertEqual(self.cs.test_func, test_func)

0 commit comments

Comments
 (0)