Skip to content

Use a tuple instead of a comma-separated string in `pytest.mark.parametrize` for better readability #6604

@harupy

Description

@harupy

Example

diff --git a/tests/utils/test_string_utils.py b/tests/utils/test_string_utils.py
index 47209eebb..ef933c1a3 100644
--- a/tests/utils/test_string_utils.py
+++ b/tests/utils/test_string_utils.py
@@ -4,7 +4,7 @@ from mlflow.utils.string_utils import strip_prefix, strip_suffix, is_string_type
 
 
 @pytest.mark.parametrize(
-    "original,prefix,expected",
+    ("original", "prefix", "expected"),
     [("smoketest", "smoke", "test"), ("", "test", ""), ("", "", ""), ("test", "", "test")],
 )
 def test_strip_prefix(original, prefix, expected):
@@ -12,7 +12,7 @@ def test_strip_prefix(original, prefix, expected):
 
 
 @pytest.mark.parametrize(
-    "original,suffix,expected",
+    ("original", "suffix", "expected"),
     [("smoketest", "test", "smoke"), ("", "test", ""), ("", "", ""), ("test", "", "test")],
 )
 def test_strip_suffix(original, suffix, expected):

Instructions

  1. Save the following code as `a.py` and run it from the repository root.
  2. Remove `a.py`.
  3. Run `black tests` to reformat the rewritten files.
  4. File a PR.
import ast
import subprocess
import re
from typing import Optional, List
from pathlib import Path


def get_qualname(node: ast.AST) -> Optional[str]:
    """Return the dotted name of a ``Name``/``Attribute`` chain, or ``None``.

    For example, the expression ``pytest.mark.parametrize`` yields the string
    ``"pytest.mark.parametrize"``. If anything other than attribute access on
    a plain name appears in the chain (a call, a subscript, ...), the result
    is ``None``.
    """
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if not isinstance(node, ast.Name):
        return None
    parts.append(node.id)
    return ".".join(reversed(parts))


class Replacer(ast.NodeVisitor):
    """Rewrite comma-separated ``pytest.mark.parametrize`` argnames as tuples.

    ``lines`` is the source text split into lines. It is edited in place
    while the visitor walks an AST parsed from that same text, so read
    ``replacer.lines`` after calling ``visit``.
    """

    def __init__(self, lines: List[str]) -> None:
        self.lines = lines

    def visit_Call(self, node: ast.Call) -> None:
        # ``node.args`` may be empty when argnames is passed by keyword only;
        # indexing it unconditionally would raise IndexError.
        if get_qualname(node.func) == "pytest.mark.parametrize" and node.args:
            self._rewrite_argnames(node.args[0])
        self.generic_visit(node)

    def _rewrite_argnames(self, arg: ast.expr) -> None:
        """Replace *arg* in ``self.lines`` with a tuple literal of its names.

        Only single-line string literals containing a comma are rewritten;
        anything else is left untouched.
        """
        if not (
            isinstance(arg, ast.Constant)
            and isinstance(arg.value, str)
            and "," in arg.value
            # The splice below uses one line's offsets, so skip literals
            # spanning several lines (e.g. implicit concatenation).
            and arg.end_lineno == arg.lineno
        ):
            return
        # Split the way pytest does for string argnames: strip each name and
        # ignore empty ones (e.g. from a trailing comma).
        names = tuple(name.strip() for name in arg.value.split(",") if name.strip())
        if not names:
            return
        row = arg.lineno - 1
        # ``col_offset``/``end_col_offset`` are UTF-8 *byte* offsets, so slice
        # the encoded line; this also replaces the literal exactly, whatever
        # its quote style or prefix.
        line = self.lines[row].encode("utf-8")
        new_line = (
            line[: arg.col_offset]
            + repr(names).encode("utf-8")
            + line[arg.end_col_offset :]
        )
        self.lines[row] = new_line.decode("utf-8")


def is_python_file(path: str) -> bool:
    """Tell whether *path* ends with the ``.py`` extension."""
    # The last three characters must be exactly ".py"; shorter paths
    # slice to something shorter and can never match.
    return path[-3:] == ".py"


def main() -> None:
    """Rewrite parametrize argnames in every tracked ``.py`` file under tests/."""
    # ``--`` separates options from the pathspec; ``--directory`` was dropped
    # because it only affects ``--others`` listings and was a no-op here.
    git_ls_files_out = subprocess.check_output(
        ["git", "ls-files", "--", "tests"],
        text=True,
    )

    for f in map(Path, filter(is_python_file, git_ls_files_out.splitlines())):
        print("Processing", f)
        # Explicit encoding: the platform default may not be UTF-8.
        src = f.read_text(encoding="utf-8")
        replacer = Replacer(lines=src.split("\n"))
        replacer.visit(ast.parse(src, filename=str(f)))
        new_src = "\n".join(replacer.lines)
        # Leave untouched files alone so their mtimes and line endings survive.
        if new_src != src:
            f.write_text(new_src, encoding="utf-8")


if __name__ == "__main__":
    main()

References

Metadata

Metadata

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions