|
3 | 3 | # SPDX-License-Identifier: GPL-3.0-or-later |
4 | 4 |
|
5 | 5 | from pathlib import Path |
6 | | -from typing import Any, Iterator, Literal, Tuple, Union |
| 6 | +from typing import Any, Literal, Tuple, Union |
7 | 7 |
|
8 | 8 | import tomlkit |
9 | | -from tomlkit.items import Table |
| 9 | +from tomlkit.toml_document import TOMLDocument |
10 | 10 |
|
11 | 11 | from .._errors import VersionError |
12 | 12 | from .._version import Version, VersionUpdate |
|
16 | 16 | class CargoVersionCommand(VersionCommand): |
17 | 17 | project_file_name = "Cargo.toml" |
18 | 18 |
|
| 19 | + def __has_package_version(self, toml: TOMLDocument): |
| 20 | + """ |
| 21 | + Checks if the 'package' table contains a 'version'. |
| 22 | + """ |
| 23 | + # get a pure python object (recursively) |
| 24 | + toml_dict = toml.unwrap() |
| 25 | + return "package" in toml_dict and "version" in toml_dict["package"] |
| 26 | + |
| 27 | + def __has_workspace_package_version(self, toml: TOMLDocument): |
| 28 | + """ |
| 29 | + Checks if the 'workspace.package' table contains a 'version'. |
| 30 | + """ |
| 31 | + # get a pure python object (recursively) |
| 32 | + toml_dict = toml.unwrap() |
| 33 | + return ( |
| 34 | + "workspace" in toml_dict |
| 35 | + and "package" in toml_dict["workspace"] |
| 36 | + and "version" in toml_dict["workspace"]["package"] |
| 37 | + ) |
| 38 | + |
19 | 39 | def __as_project_document( |
20 | 40 | self, origin: Path |
21 | | - ) -> Iterator[Tuple[Path, tomlkit.TOMLDocument],]: |
| 41 | + ) -> Tuple[Path, tomlkit.TOMLDocument]: |
22 | 42 | """ |
23 | | - Parse the given origin and yields a tuple of path to a |
24 | | - Cargo.toml that contains a version |
| 43 | + Parse the given origin and returns a tuple of path to a |
| 44 | + Cargo.toml that contains a version. |
25 | 45 | Version should be in the table package or workspace.package |
26 | 46 |
|
27 | 47 | If the origin is invalid toml than it will raise a VersionError. |
28 | 48 | """ |
29 | 49 | version: Any = None |
30 | 50 | content = origin.read_text(encoding="utf-8") |
31 | 51 | content = tomlkit.parse(content) |
32 | | - package = content.get("package") |
33 | | - workspace = content.get("workspace") |
34 | | - if not isinstance(package, Table) and isinstance(workspace, Table): |
35 | | - package = workspace.get("package") |
36 | | - if isinstance(package, Table): |
37 | | - version = package.get("version", "") |
| 52 | + if self.__has_workspace_package_version(content): |
| 53 | + version = content.get("workspace").get("package").get("version") # type: ignore[union-attr] |
| 54 | + |
| 55 | + if self.__has_package_version(content): |
| 56 | + version = content.get("package").get("version") # type: ignore[union-attr] |
| 57 | + |
38 | 58 | if version: |
39 | | - yield (origin, content) |
| 59 | + return (origin, content) |
40 | 60 | else: |
41 | | - # check sub directories for toml files with version |
42 | | - if isinstance(workspace, Table): |
43 | | - members = workspace.get("members") |
44 | | - if members: |
45 | | - for member in members: |
46 | | - yield from self.__as_project_document( |
47 | | - origin.parent / member / self.project_file_name |
48 | | - ) |
49 | | - return None |
| 61 | + raise VersionError( |
| 62 | + f"No {origin} file found. This file is required for pontos." |
| 63 | + ) |
50 | 64 |
|
51 | 65 | def update_version( |
52 | 66 | self, new_version: Version, *, force: bool = False |
53 | 67 | ) -> VersionUpdate: |
54 | 68 | try: |
55 | 69 | previous_version = self.get_current_version() |
56 | | - |
57 | 70 | if not force and new_version == previous_version: |
58 | 71 | return VersionUpdate(previous=previous_version, new=new_version) |
59 | 72 | except VersionError: |
60 | 73 | # just ignore current version and override it |
61 | 74 | previous_version = None |
62 | 75 |
|
63 | | - changed_files = [] |
64 | | - for project_path, project in self.__as_project_document( |
| 76 | + project_path, project = self.__as_project_document( |
65 | 77 | self.project_file_path |
66 | | - ): |
| 78 | + ) |
| 79 | + |
| 80 | + if self.__has_workspace_package_version(project): |
| 81 | + # Set the version for all members of the workspace. Members of the |
| 82 | + # workspace should use `version.workspace=true` in the 'package' table, |
| 83 | + # if they are released together with the parent crate. |
| 84 | + project["workspace"]["package"]["version"] = str(new_version) # type: ignore[index] # noqa: E501 |
| 85 | + |
| 86 | + if self.__has_package_version(project): |
| 87 | + # Set the 'version' of the 'package' table for the parent crate |
67 | 88 | project["package"]["version"] = str(new_version) # type: ignore[index] # noqa: E501 |
68 | | - project_path.write_text(tomlkit.dumps(project)) |
69 | | - changed_files.append(project_path) |
| 89 | + |
| 90 | + project_path.write_text(tomlkit.dumps(project)) |
70 | 91 | return VersionUpdate( |
71 | 92 | previous=previous_version, |
72 | 93 | new=new_version, |
73 | | - changed_files=changed_files, |
| 94 | + changed_files=[project_path], |
74 | 95 | ) |
75 | 96 |
|
76 | 97 | def get_current_version(self) -> Version: |
77 | | - (_, document) = next(self.__as_project_document(self.project_file_path)) |
78 | | - try: |
79 | | - version = document["package"]["version"] # type: ignore[index, arg-type] |
80 | | - except KeyError: |
| 98 | + (_, document) = self.__as_project_document(self.project_file_path) |
| 99 | + |
| 100 | + version: Any = None |
| 101 | + if self.__has_workspace_package_version(document): |
81 | 102 | version = document["workspace"]["package"]["version"] # type: ignore[index, arg-type] |
| 103 | + |
| 104 | + if self.__has_package_version(document): |
| 105 | + # If the 'package' table has a 'version', it always has precedence over the |
| 106 | + # 'version' in the 'workspace.package' table (they are assumed to be equal, if |
| 107 | + # managed by pontos) |
| 108 | + version = document["package"]["version"] # type: ignore[index, arg-type] |
| 109 | + |
82 | 110 | if isinstance(version, str): |
83 | 111 | current_version = self.versioning_scheme.parse_version(version) |
| 112 | + |
84 | 113 | return self.versioning_scheme.from_version(current_version) |
85 | 114 |
|
86 | 115 | def verify_version( |
|
0 commit comments