1313from abc import ABC , abstractmethod
1414from collections import namedtuple
1515from functools import wraps
16- from warnings import warn
1716
1817import dask .array as da
1918import numpy as np
2928 SERVICES_DIFFERENCE ,
3029)
3130from ..common .lenient import _lenient_service as lenient_service
31+ from ..config import get_logger
3232from ..coords import _DimensionalMetadata
3333from .. import exceptions
3434
3535
3636__all__ = [
3737 "Connectivity" ,
3838 "ConnectivityMetadata" ,
39+ "Mesh1DConnectivities" ,
40+ "Mesh2DConnectivities" ,
3941]
4042
4143
44+ # Configure the logger.
45+ logger = get_logger (__name__ , fmt = "[%(cls)s.%(funcName)s]" )
46+
47+ Mesh1DConnectivities = namedtuple ("Mesh1DConnectivities" , ["edge_node" ])
48+ Mesh2DConnectivities = namedtuple (
49+ "Mesh2DConnectivities" ,
50+ [
51+ "face_node" ,
52+ "edge_node" ,
53+ "face_edge" ,
54+ "face_face" ,
55+ "edge_face" ,
56+ "boundary_node" ,
57+ ],
58+ )
59+
60+
4261class Connectivity (_DimensionalMetadata ):
4362 """
4463 A CF-UGRID topology connectivity, describing the topological relationship
@@ -745,7 +764,6 @@ class _MeshConnectivityManagerMixin(ABC):
745764 REQUIRED = ()
746765 OPTIONAL = ()
747766 NDIM = NotImplemented
748- MembersTuple = NotImplemented
749767
750768 @abstractmethod
751769 def __init__ (self , * connectivities ):
@@ -755,18 +773,20 @@ def __init__(self, *connectivities):
755773 message = f"{ self .NDIM } D meshes require a { requisite } ."
756774 raise ValueError (message )
757775
758- self .valid_roles = self .REQUIRED + self .OPTIONAL
776+ self .ALL = self .REQUIRED + self .OPTIONAL
759777 self ._members = {}
760778 self .add (* connectivities )
761779
762780 def __iter__ (self ):
763- for member , connectivity in self ._members .items ():
764- yield member , connectivity
781+ for item in self ._members .items ():
782+ yield item
765783
766784 def __getstate__ (self ):
785+ # TBD
767786 pass
768787
769788 def __setstate__ (self , state ):
789+ # TBD
770790 pass
771791
772792 @property
@@ -785,7 +805,6 @@ def filter(
785805 edge = None ,
786806 face = None ,
787807 ):
788- # see Cube.coords for relevant patterns
789808 result = filter_cf (
790809 self ._members .values (),
791810 item = item ,
@@ -795,24 +814,26 @@ def filter(
795814 attributes = attributes ,
796815 )
797816
798- def location_filter (instances , parameter , location_name ):
799- result = instances
800- if parameter is True :
801- result = [
817+ def location_filter (instances_ , parameter_ , location_name_ ):
818+ if parameter_ is False :
819+ result_ = [
802820 instance_
803- for instance_ in result
804- if location_name
805- in (instance_ .src_location , instance_ .tgt_location )
821+ for instance_ in instances_
822+ if location_name_
823+ not in (instance_ .src_location , instance_ .tgt_location )
806824 ]
807- elif parameter is False :
808- result = [
825+ elif parameter_ is None :
826+ result_ = instances_
827+ else :
828+ # Interpret any other value as =True.
829+ result_ = [
809830 instance_
810- for instance_ in result
811- if location_name
812- not in (instance_ .src_location , instance_ .tgt_location )
831+ for instance_ in instances_
832+ if location_name_
833+ in (instance_ .src_location , instance_ .tgt_location )
813834 ]
814835
815- return result
836+ return result_
816837
817838 for parameter , location_name in (
818839 (node , "node" ),
@@ -821,7 +842,8 @@ def location_filter(instances, parameter, location_name):
821842 ):
822843 result = location_filter (result , parameter , location_name )
823844
824- return result
845+ result_dict = {k : v for k , v in self ._members .items () if v in result }
846+ return result_dict
825847
826848 def filter_single (self , ** kwargs ):
827849 result = self .filter (** kwargs )
@@ -856,14 +878,18 @@ def add(self, *connectivities):
856878 # validate their outputs.
857879 add_dict = {}
858880 for connectivity in connectivities :
859- assert isinstance (connectivity , Connectivity )
881+ if not isinstance (connectivity , Connectivity ):
882+ message = f"Expected Connectivity, got: { type (connectivity )} ."
883+ raise ValueError (message )
860884 cf_role = connectivity .cf_role
861- if cf_role not in self .valid_roles :
885+ if cf_role not in self .ALL :
862886 message = (
863887 f"Connectivity not added. Got cf_role={ cf_role } . "
864- f"Expected one of: { self .valid_roles } ."
888+ f"Expected one of: { self .ALL } ."
889+ )
890+ logger .warning (
891+ message , extra = dict (cls = self .__class__ .__name__ )
865892 )
866- warn (message )
867893 else :
868894 add_dict [cf_role ] = connectivity
869895
@@ -916,46 +942,41 @@ def remove(
916942 f"Connectivity not removed: { cf_role } - required "
917943 f"for a valid { self .NDIM } D Mesh."
918944 )
919- warn (message )
945+ logger .warning (
946+ message , extra = dict (cls = self .__class__ .__name__ )
947+ )
920948
921949 for cf_role in removal_dict .keys ():
922950 del self ._members [cf_role ]
923951
924952 return removal_dict
925953
926954 def __repr__ (self ):
927- class_name = type (self ).__name__
928- content = ", " .join (
929- f"{ member } ={ connectivity } " for member , connectivity in self
930- )
931- return ", " .join ((class_name , content ))
955+ args = [f"{ member } ={ connectivity } " for member , connectivity in self ]
956+ return f"{ self .__class__ .__name__ } ({ ', ' .join (args )} )"
932957
933958 def __eq__ (self , other ):
934- # Full equality could be MASSIVE, so we want to avoid that.
935- # Ideally we want a mesh signature from LFRic for comparison, although this would
936- # limit Iris' relevance outside MO.
937- # TL;DR: unknown quantity.
938- raise NotImplementedError
959+ # TBD
960+ return NotImplemented
939961
940962 def __ne__ (self , other ):
941- # See __eq__
942- raise NotImplementedError
963+ # TBD
964+ return NotImplemented
943965
944966
945967# keep an eye on the __init__ inheritance
946968class _Mesh1DConnectivityManager (_MeshConnectivityManagerMixin ):
947969 REQUIRED = ("edge_node_connectivity" ,)
948970 OPTIONAL = ()
949971 NDIM = 1
950- MembersTuple = namedtuple ("Mesh1DConnectivities" , ["edge_node" ])
951972
952973 def __init__ (self , * connectivities ):
953974 super ().__init__ (* connectivities )
954975
955- # TODO: debatable whether one couldn't just use self.filter().
976+ # TODO: debatable whether a user couldn't just use self.filter() with no args .
956977 @property
957978 def all_members (self ):
958- return self . MembersTuple (self .edge_node )
979+ return Mesh1DConnectivities (self .edge_node )
959980
960981 @property
961982 def edge_node (self ):
@@ -972,25 +993,14 @@ class _Mesh2DConnectivityManager(_MeshConnectivityManagerMixin):
972993 "boundary_node_connectivity" ,
973994 )
974995 NDIM = 2
975- MembersTuple = namedtuple (
976- "Mesh2DConnectivities" ,
977- [
978- "face_node" ,
979- "edge_node" ,
980- "face_edge" ,
981- "face_face" ,
982- "edge_face" ,
983- "boundary_node" ,
984- ],
985- )
986996
987997 def __init__ (self , * connectivities ):
988998 super ().__init__ (* connectivities )
989999
990- # TODO: debatable whether one couldn't just use self.filter().
1000+ # TODO: debatable whether a user couldn't just use self.filter() with no args .
9911001 @property
9921002 def all_members (self ):
993- return self . MembersTuple (
1003+ return Mesh2DConnectivities (
9941004 self .face_node ,
9951005 self .edge_node ,
9961006 self .face_edge ,
0 commit comments