|
13 | 13 | import importlib |
14 | 14 | from collections.abc import Iterable |
15 | 15 | from threading import RLock |
16 | | -from typing import Any, Optional, Union |
| 16 | +from typing import Any, Optional |
17 | 17 |
|
18 | 18 | from holidays.holiday_base import HolidayBase |
19 | 19 |
|
@@ -339,35 +339,46 @@ def get_entity(self) -> Optional[HolidayBase]: |
339 | 339 | @staticmethod |
340 | 340 | def _get_entity_codes( |
341 | 341 | container: RegistryDict, |
342 | | - entity_length: Union[int, Iterable[int]], |
343 | 342 | include_aliases: bool = True, |
| 343 | + max_code_length: int = 3, |
| 344 | + min_code_length: int = 2, |
344 | 345 | ) -> Iterable[str]: |
345 | | - entity_length = {entity_length} if isinstance(entity_length, int) else set(entity_length) |
346 | 346 | for entities in container.values(): |
347 | | - for entity in entities: |
348 | | - if len(entity) in entity_length: |
349 | | - yield entity |
350 | | - # Assuming that the alpha-2 code goes first. |
351 | | - if not include_aliases: |
352 | | - break |
| 347 | + for code in entities[1:]: |
| 348 | + if min_code_length <= len(code) <= max_code_length: |
| 349 | + yield code |
| 350 | + |
| 351 | + # Stop after the first matching code if aliases are not requested. |
| 352 | + # Assuming that the alpha-2 code goes first. |
| 353 | + if not include_aliases: |
| 354 | + break |
353 | 355 |
|
354 | 356 | @staticmethod |
355 | 357 | def get_country_codes(include_aliases: bool = True) -> Iterable[str]: |
356 | 358 | """Get supported country codes. |
357 | 359 |
|
358 | 360 | :param include_aliases: |
359 | | - Whether to include entity aliases (e.g. UK for GB). |
| 361 | + Whether to include entity aliases (e.g. GBR and UK for GB, |
| 362 | + UKR for UA, USA for US, etc). |
360 | 363 | """ |
361 | | - return EntityLoader._get_entity_codes(COUNTRIES, 2, include_aliases) |
| 364 | + return EntityLoader._get_entity_codes( |
| 365 | + COUNTRIES, |
| 366 | + include_aliases=include_aliases, |
| 367 | + ) |
362 | 368 |
|
363 | 369 | @staticmethod |
364 | 370 | def get_financial_codes(include_aliases: bool = True) -> Iterable[str]: |
365 | 371 | """Get supported financial codes. |
366 | 372 |
|
367 | 373 | :param include_aliases: |
368 | | - Whether to include entity aliases(e.g. TAR for ECB, XNYS for NYSE). |
| 374 | + Whether to include entity aliases (e.g. B3 for BVMF, |
| 375 | + TAR for ECB, NYSE for XNYS, etc). |
369 | 376 | """ |
370 | | - return EntityLoader._get_entity_codes(FINANCIAL, (3, 4), include_aliases) |
| 377 | + return EntityLoader._get_entity_codes( |
| 378 | + FINANCIAL, |
| 379 | + include_aliases=include_aliases, |
| 380 | + max_code_length=4, |
| 381 | + ) |
371 | 382 |
|
372 | 383 | @staticmethod |
373 | 384 | def load(prefix: str, scope: dict) -> None: |
|
0 commit comments