Skip to content

Commit 004c289

Browse files
committed
ConnectivityManager align with proposed CoordManager.
1 parent 99dbee9 commit 004c289

File tree

2 files changed

+66
-54
lines changed

2 files changed

+66
-54
lines changed

lib/iris/common/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ def filter_cf(
6868

6969
def attr_filter(instance_):
7070
return all(
71-
k in instance_.attributes and instance_.attributes[k] == v
71+
k in instance_.attributes
72+
and metadata._hexdigest(instance_.attributes[k])
73+
== metadata._hexdigest(v)
7274
for k, v in attributes.items()
7375
)
7476

lib/iris/experimental/ugrid.py

Lines changed: 63 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from abc import ABC, abstractmethod
1414
from collections import namedtuple
1515
from functools import wraps
16-
from warnings import warn
1716

1817
import dask.array as da
1918
import numpy as np
@@ -29,16 +28,36 @@
2928
SERVICES_DIFFERENCE,
3029
)
3130
from ..common.lenient import _lenient_service as lenient_service
31+
from ..config import get_logger
3232
from ..coords import _DimensionalMetadata
3333
from .. import exceptions
3434

3535

3636
__all__ = [
3737
"Connectivity",
3838
"ConnectivityMetadata",
39+
"Mesh1DConnectivities",
40+
"Mesh2DConnectivities",
3941
]
4042

4143

44+
# Configure the logger.
45+
logger = get_logger(__name__, fmt="[%(cls)s.%(funcName)s]")
46+
47+
Mesh1DConnectivities = namedtuple("Mesh1DConnectivities", ["edge_node"])
48+
Mesh2DConnectivities = namedtuple(
49+
"Mesh2DConnectivities",
50+
[
51+
"face_node",
52+
"edge_node",
53+
"face_edge",
54+
"face_face",
55+
"edge_face",
56+
"boundary_node",
57+
],
58+
)
59+
60+
4261
class Connectivity(_DimensionalMetadata):
4362
"""
4463
A CF-UGRID topology connectivity, describing the topological relationship
@@ -745,7 +764,6 @@ class _MeshConnectivityManagerMixin(ABC):
745764
REQUIRED = ()
746765
OPTIONAL = ()
747766
NDIM = NotImplemented
748-
MembersTuple = NotImplemented
749767

750768
@abstractmethod
751769
def __init__(self, *connectivities):
@@ -755,18 +773,20 @@ def __init__(self, *connectivities):
755773
message = f"{self.NDIM}D meshes require a {requisite}."
756774
raise ValueError(message)
757775

758-
self.valid_roles = self.REQUIRED + self.OPTIONAL
776+
self.ALL = self.REQUIRED + self.OPTIONAL
759777
self._members = {}
760778
self.add(*connectivities)
761779

762780
def __iter__(self):
763-
for member, connectivity in self._members.items():
764-
yield member, connectivity
781+
for item in self._members.items():
782+
yield item
765783

766784
def __getstate__(self):
785+
# TBD
767786
pass
768787

769788
def __setstate__(self, state):
789+
# TBD
770790
pass
771791

772792
@property
@@ -785,7 +805,6 @@ def filter(
785805
edge=None,
786806
face=None,
787807
):
788-
# see Cube.coords for relevant patterns
789808
result = filter_cf(
790809
self._members.values(),
791810
item=item,
@@ -795,24 +814,26 @@ def filter(
795814
attributes=attributes,
796815
)
797816

798-
def location_filter(instances, parameter, location_name):
799-
result = instances
800-
if parameter is True:
801-
result = [
817+
def location_filter(instances_, parameter_, location_name_):
818+
if parameter_ is False:
819+
result_ = [
802820
instance_
803-
for instance_ in result
804-
if location_name
805-
in (instance_.src_location, instance_.tgt_location)
821+
for instance_ in instances_
822+
if location_name_
823+
not in (instance_.src_location, instance_.tgt_location)
806824
]
807-
elif parameter is False:
808-
result = [
825+
elif parameter_ is None:
826+
result_ = instances_
827+
else:
828+
# Interpret any other value as =True.
829+
result_ = [
809830
instance_
810-
for instance_ in result
811-
if location_name
812-
not in (instance_.src_location, instance_.tgt_location)
831+
for instance_ in instances_
832+
if location_name_
833+
in (instance_.src_location, instance_.tgt_location)
813834
]
814835

815-
return result
836+
return result_
816837

817838
for parameter, location_name in (
818839
(node, "node"),
@@ -821,7 +842,8 @@ def location_filter(instances, parameter, location_name):
821842
):
822843
result = location_filter(result, parameter, location_name)
823844

824-
return result
845+
result_dict = {k: v for k, v in self._members.items() if v in result}
846+
return result_dict
825847

826848
def filter_single(self, **kwargs):
827849
result = self.filter(**kwargs)
@@ -856,14 +878,18 @@ def add(self, *connectivities):
856878
# validate their outputs.
857879
add_dict = {}
858880
for connectivity in connectivities:
859-
assert isinstance(connectivity, Connectivity)
881+
if not isinstance(connectivity, Connectivity):
882+
message = f"Expected Connectivity, got: {type(connectivity)} ."
883+
raise ValueError(message)
860884
cf_role = connectivity.cf_role
861-
if cf_role not in self.valid_roles:
885+
if cf_role not in self.ALL:
862886
message = (
863887
f"Connectivity not added. Got cf_role={cf_role} . "
864-
f"Expected one of: {self.valid_roles} ."
888+
f"Expected one of: {self.ALL} ."
889+
)
890+
logger.warning(
891+
message, extra=dict(cls=self.__class__.__name__)
865892
)
866-
warn(message)
867893
else:
868894
add_dict[cf_role] = connectivity
869895

@@ -916,46 +942,41 @@ def remove(
916942
f"Connectivity not removed: {cf_role} - required "
917943
f"for a valid {self.NDIM}D Mesh."
918944
)
919-
warn(message)
945+
logger.warning(
946+
message, extra=dict(cls=self.__class__.__name__)
947+
)
920948

921949
for cf_role in removal_dict.keys():
922950
del self._members[cf_role]
923951

924952
return removal_dict
925953

926954
def __repr__(self):
927-
class_name = type(self).__name__
928-
content = ", ".join(
929-
f"{member}={connectivity}" for member, connectivity in self
930-
)
931-
return ", ".join((class_name, content))
955+
args = [f"{member}={connectivity}" for member, connectivity in self]
956+
return f"{self.__class__.__name__}({', '.join(args)})"
932957

933958
def __eq__(self, other):
934-
# Full equality could be MASSIVE, so we want to avoid that.
935-
# Ideally we want a mesh signature from LFRic for comparison, although this would
936-
# limit Iris' relevance outside MO.
937-
# TL;DR: unknown quantity.
938-
raise NotImplementedError
959+
# TBD
960+
return NotImplemented
939961

940962
def __ne__(self, other):
941-
# See __eq__
942-
raise NotImplementedError
963+
# TBD
964+
return NotImplemented
943965

944966

945967
# keep an eye on the __init__ inheritance
946968
class _Mesh1DConnectivityManager(_MeshConnectivityManagerMixin):
947969
REQUIRED = ("edge_node_connectivity",)
948970
OPTIONAL = ()
949971
NDIM = 1
950-
MembersTuple = namedtuple("Mesh1DConnectivities", ["edge_node"])
951972

952973
def __init__(self, *connectivities):
953974
super().__init__(*connectivities)
954975

955-
# TODO: debatable whether one couldn't just use self.filter().
976+
# TODO: debatable whether a user couldn't just use self.filter() with no args.
956977
@property
957978
def all_members(self):
958-
return self.MembersTuple(self.edge_node)
979+
return Mesh1DConnectivities(self.edge_node)
959980

960981
@property
961982
def edge_node(self):
@@ -972,25 +993,14 @@ class _Mesh2DConnectivityManager(_MeshConnectivityManagerMixin):
972993
"boundary_node_connectivity",
973994
)
974995
NDIM = 2
975-
MembersTuple = namedtuple(
976-
"Mesh2DConnectivities",
977-
[
978-
"face_node",
979-
"edge_node",
980-
"face_edge",
981-
"face_face",
982-
"edge_face",
983-
"boundary_node",
984-
],
985-
)
986996

987997
def __init__(self, *connectivities):
988998
super().__init__(*connectivities)
989999

990-
# TODO: debatable whether one couldn't just use self.filter().
1000+
# TODO: debatable whether a user couldn't just use self.filter() with no args.
9911001
@property
9921002
def all_members(self):
993-
return self.MembersTuple(
1003+
return Mesh2DConnectivities(
9941004
self.face_node,
9951005
self.edge_node,
9961006
self.face_edge,

0 commit comments

Comments
 (0)