Skip to content

Commit 515ba78

Browse files
authored
Expose entity additional codes (#2879)
1 parent 13a6d8d commit 515ba78

File tree

2 files changed

+60
-13
lines changed

2 files changed

+60
-13
lines changed

holidays/registry.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import importlib
1414
from collections.abc import Iterable
1515
from threading import RLock
16-
from typing import Any, Optional, Union
16+
from typing import Any, Optional
1717

1818
from holidays.holiday_base import HolidayBase
1919

@@ -339,35 +339,46 @@ def get_entity(self) -> Optional[HolidayBase]:
339339
@staticmethod
340340
def _get_entity_codes(
341341
container: RegistryDict,
342-
entity_length: Union[int, Iterable[int]],
343342
include_aliases: bool = True,
343+
max_code_length: int = 3,
344+
min_code_length: int = 2,
344345
) -> Iterable[str]:
345-
entity_length = {entity_length} if isinstance(entity_length, int) else set(entity_length)
346346
for entities in container.values():
347-
for entity in entities:
348-
if len(entity) in entity_length:
349-
yield entity
350-
# Assuming that the alpha-2 code goes first.
351-
if not include_aliases:
352-
break
347+
for code in entities[1:]:
348+
if min_code_length <= len(code) <= max_code_length:
349+
yield code
350+
351+
# Stop after the first matching code if aliases are not requested.
352+
# Assuming that the alpha-2 code goes first.
353+
if not include_aliases:
354+
break
353355

354356
@staticmethod
355357
def get_country_codes(include_aliases: bool = True) -> Iterable[str]:
356358
"""Get supported country codes.
357359
358360
:param include_aliases:
359-
Whether to include entity aliases (e.g. UK for GB).
361+
Whether to include entity aliases (e.g. GBR and UK for GB,
362+
UKR for UA, USA for US, etc).
360363
"""
361-
return EntityLoader._get_entity_codes(COUNTRIES, 2, include_aliases)
364+
return EntityLoader._get_entity_codes(
365+
COUNTRIES,
366+
include_aliases=include_aliases,
367+
)
362368

363369
@staticmethod
364370
def get_financial_codes(include_aliases: bool = True) -> Iterable[str]:
365371
"""Get supported financial codes.
366372
367373
:param include_aliases:
368-
Whether to include entity aliases(e.g. TAR for ECB, XNYS for NYSE).
374+
Whether to include entity aliases (e.g. B3 for BVMF,
375+
TAR for ECB, NYSE for XNYS, etc).
369376
"""
370-
return EntityLoader._get_entity_codes(FINANCIAL, (3, 4), include_aliases)
377+
return EntityLoader._get_entity_codes(
378+
FINANCIAL,
379+
include_aliases=include_aliases,
380+
max_code_length=4,
381+
)
371382

372383
@staticmethod
373384
def load(prefix: str, scope: dict) -> None:

tests/test_registry.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,42 @@ def test_financial_str(self):
108108
"'holidays.financial.ny_stock_exchange.NYSE' class directly.",
109109
)
110110

111+
def test_get_country_codes(self):
112+
country_codes = set(registry.EntityLoader.get_country_codes(include_aliases=False))
113+
for entity_classes in registry.COUNTRIES.values():
114+
self.assertNotIn(entity_classes[0], country_codes)
115+
self.assertIn(entity_classes[1], country_codes)
116+
for code in entity_classes[2:]:
117+
self.assertNotIn(code, country_codes)
118+
119+
def test_get_country_codes_aliases(self):
120+
country_codes = set(registry.EntityLoader.get_country_codes(include_aliases=True))
121+
for entity_classes in registry.COUNTRIES.values():
122+
self.assertNotIn(entity_classes[0], country_codes)
123+
for code in entity_classes[1:]:
124+
if code.isupper():
125+
self.assertIn(code, country_codes)
126+
else:
127+
self.assertNotIn(code, country_codes)
128+
129+
def test_get_financial_codes(self):
130+
financial_codes = set(registry.EntityLoader.get_financial_codes(include_aliases=False))
131+
for entity_classes in registry.FINANCIAL.values():
132+
self.assertNotIn(entity_classes[0], financial_codes)
133+
self.assertIn(entity_classes[1], financial_codes)
134+
for code in entity_classes[2:]:
135+
self.assertNotIn(code, financial_codes)
136+
137+
def test_get_financial_codes_aliases(self):
138+
financial_codes = set(registry.EntityLoader.get_financial_codes(include_aliases=True))
139+
for entity_classes in registry.FINANCIAL.values():
140+
self.assertNotIn(entity_classes[0], financial_codes)
141+
for code in entity_classes[1:]:
142+
if code.isupper():
143+
self.assertIn(code, financial_codes)
144+
else:
145+
self.assertNotIn(code, financial_codes)
146+
111147
def test_inheritance(self):
112148
def create_instance(parent):
113149
class SubClass(parent):

0 commit comments

Comments
 (0)