11use ruff_macros:: { ViolationMetadata , derive_message_formats} ;
22use ruff_python_ast:: helpers:: ReturnStatementVisitor ;
33use ruff_python_ast:: visitor:: Visitor ;
4- use ruff_python_ast:: { self as ast, Expr , StmtFunctionDef } ;
5- use ruff_python_semantic:: Modules ;
4+ use ruff_python_ast:: { self as ast, Expr , ExprCall , Stmt , StmtFunctionDef } ;
5+ use ruff_python_semantic:: { BindingKind , Modules , ScopeKind } ;
66use ruff_text_size:: Ranged ;
77
88use crate :: Violation ;
99use crate :: checkers:: ast:: Checker ;
1010use crate :: rules:: airflow:: helpers:: is_airflow_task_variant;
1111
1212/// ## What it does
13- /// Checks for `@task.branch` decorated functions that could be replaced
14- /// with `@task.short_circuit`.
13+ /// Checks for branching logic that could be replaced with a short-circuit
14+ /// pattern, either via `@task.branch` decorated functions or
15+ /// `BranchPythonOperator` callables.
1516///
1617/// ## Why is this bad?
17- /// When a `@task. branch` function has at least two `return` statements and
18- /// exactly one of them returns a non-empty list, the function is effectively
19- /// acting as a short-circuit operator. Using `@task.short_circuit` is
20- /// simpler and more readable in such cases.
18+ /// When a branch function has at least two `return` statements and exactly
19+ /// one of them returns a non-empty list, the function is effectively acting
20+ /// as a short-circuit operator. Using `@task.short_circuit` or
21+ /// `ShortCircuitOperator` is simpler and more readable in such cases.
2122///
2223/// ## Example
24+ ///
25+ /// Using the `TaskFlow` API:
2326/// ```python
2427/// from airflow.decorators import task
2528///
@@ -40,18 +43,60 @@ use crate::rules::airflow::helpers::is_airflow_task_variant;
4043/// def my_task():
4144/// return condition
4245/// ```
46+ ///
47+ /// Using the standard operator API:
48+ /// ```python
49+ /// from airflow.operators.python import BranchPythonOperator
50+ ///
51+ ///
52+ /// def my_callable():
53+ /// if condition:
54+ /// return ["my_downstream_task"]
55+ /// return []
56+ ///
57+ ///
58+ /// task = BranchPythonOperator(task_id="my_task", python_callable=my_callable)
59+ /// ```
60+ ///
61+ /// Use instead:
62+ /// ```python
63+ /// from airflow.operators.python import ShortCircuitOperator
64+ ///
65+ ///
66+ /// def my_callable():
67+ /// return condition
68+ ///
69+ ///
70+ /// task = ShortCircuitOperator(task_id="my_task", python_callable=my_callable)
71+ /// ```
4372#[ derive( ViolationMetadata ) ]
4473#[ violation_metadata( preview_since = "NEXT_RUFF_VERSION" ) ]
45- pub ( crate ) struct TaskBranchAsShortCircuit ;
74+ pub ( crate ) struct TaskBranchAsShortCircuit {
75+ kind : BranchKind ,
76+ }
77+
78+ #[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
79+ enum BranchKind {
80+ Decorator ,
81+ Operator ,
82+ }
4683
4784impl Violation for TaskBranchAsShortCircuit {
4885 #[ derive_message_formats]
4986 fn message ( & self ) -> String {
50- "A `@task.branch` that can be replaced with `@task.short_circuit`" . to_string ( )
87+ match self . kind {
88+ BranchKind :: Decorator => {
89+ "A `@task.branch` that can be replaced with `@task.short_circuit`" . to_string ( )
90+ }
91+ BranchKind :: Operator => {
92+ "A `BranchPythonOperator` that can be replaced with `ShortCircuitOperator`"
93+ . to_string ( )
94+ }
95+ }
5196 }
5297}
5398
54- /// AIR003
99+ /// AIR003 (decorator form)
55100pub ( crate ) fn task_branch_as_short_circuit ( checker : & Checker , function_def : & StmtFunctionDef ) {
56101 if !checker. semantic ( ) . seen_module ( Modules :: AIRFLOW ) {
57102 return ;
@@ -61,14 +106,88 @@ pub(crate) fn task_branch_as_short_circuit(checker: &Checker, function_def: &Stm
61106 return ;
62107 }
63108
109+ if could_be_short_circuit ( & function_def. body ) {
110+ checker. report_diagnostic (
111+ TaskBranchAsShortCircuit {
112+ kind : BranchKind :: Decorator ,
113+ } ,
114+ function_def. range ( ) ,
115+ ) ;
116+ }
117+ }
118+
119+ /// AIR003 (operator form)
120+ pub ( crate ) fn branch_python_operator_as_short_circuit ( checker : & Checker , call : & ExprCall ) {
121+ if !checker. semantic ( ) . seen_module ( Modules :: AIRFLOW ) {
122+ return ;
123+ }
124+
125+ let semantic = checker. semantic ( ) ;
126+
127+ let Some ( qualified_name) = semantic. resolve_qualified_name ( & call. func ) else {
128+ return ;
129+ } ;
130+
131+ if !matches ! (
132+ qualified_name. segments( ) ,
133+ [
134+ "airflow" ,
135+ "operators" ,
136+ "python" | "python_operator" ,
137+ "BranchPythonOperator"
138+ ] | [
139+ "airflow" ,
140+ "providers" ,
141+ "standard" ,
142+ "operators" ,
143+ "python" ,
144+ "BranchPythonOperator"
145+ ]
146+ ) {
147+ return ;
148+ }
149+
150+ let Some ( keyword) = call. arguments . find_keyword ( "python_callable" ) else {
151+ return ;
152+ } ;
153+
154+ let Expr :: Name ( name_expr) = & keyword. value else {
155+ return ;
156+ } ;
157+
158+ let Some ( binding_id) = semantic. only_binding ( name_expr) else {
159+ return ;
160+ } ;
161+
162+ let BindingKind :: FunctionDefinition ( scope_id) = semantic. binding ( binding_id) . kind else {
163+ return ;
164+ } ;
165+
166+ let ScopeKind :: Function ( function_def) = semantic. scopes [ scope_id] . kind else {
167+ return ;
168+ } ;
169+
170+ if could_be_short_circuit ( & function_def. body ) {
171+ checker. report_diagnostic (
172+ TaskBranchAsShortCircuit {
173+ kind : BranchKind :: Operator ,
174+ } ,
175+ call. func . range ( ) ,
176+ ) ;
177+ }
178+ }
179+
180+ /// Returns `true` if the function body has 2+ return statements with exactly
181+ /// one non-empty list return — indicating a short-circuit pattern.
182+ fn could_be_short_circuit ( body : & [ Stmt ] ) -> bool {
64183 let mut visitor = ReturnStatementVisitor :: default ( ) ;
65- for stmt in & function_def . body {
184+ for stmt in body {
66185 visitor. visit_stmt ( stmt) ;
67186 }
68187
69188 let returns = & visitor. returns ;
70189 if returns. len ( ) < 2 {
71- return ;
190+ return false ;
72191 }
73192
74193 let non_empty_list_count = returns
@@ -80,7 +199,5 @@ pub(crate) fn task_branch_as_short_circuit(checker: &Checker, function_def: &Stm
80199 } )
81200 . count ( ) ;
82201
83- if non_empty_list_count == 1 {
84- checker. report_diagnostic ( TaskBranchAsShortCircuit , function_def. range ( ) ) ;
85- }
202+ non_empty_list_count == 1
86203}
0 commit comments