Skip to content

Commit 267641a

Browse files
ezyang authored and facebook-github-bot committed
Rename positional and kwarg_only to have flat prefix (#49042)
Summary: Pull Request resolved: #49042 I want the names positional and kwarg_only to give the unflat representation (e.g., preserving TensorOptionsArguments in the returned Union). So I regret my original naming choice when I moved grouping to model. This renames them to have flat_ prefix and also adds a flat_non_out argument for cases where you just want to look at non-out arguments. Signed-off-by: Edward Z. Yang <[email protected]> Test Plan: Imported from OSS Reviewed By: H-Huang Differential Revision: D25455884 Pulled By: ezyang fbshipit-source-id: f923f8881267a3e3e8e9521519412f7cc25034fc
1 parent 0dea76e commit 267641a

File tree

7 files changed

+37
-23
lines changed

7 files changed

+37
-23
lines changed

tools/autograd/gen_annotated_fn_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None:
5252
@with_native_function
5353
def gen_annotated_args(f: NativeFunction) -> str:
5454
out_args: List[Dict[str, Any]] = []
55-
for arg in f.func.arguments.positional:
55+
for arg in f.func.arguments.flat_positional:
5656
if arg.default is not None:
5757
continue
5858
out_arg: Dict[str, Any] = {}

tools/autograd/gen_trace_type.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,7 @@ def dispatch_trace_input(arg: Union[Argument, TensorOptionsArguments]) -> Sequen
150150
# Factories are a bit special because their out-of-place overloads
151151
# take an extra TensorOptions argument, which is missing in the _out function
152152
has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns)
153-
has_tensor_input_arg = any(a.type.is_tensor_like()
154-
for a in itertools.chain(f.func.arguments.positional, f.func.arguments.kwarg_only))
153+
has_tensor_input_arg = any(a.type.is_tensor_like() for a in f.func.arguments.flat_non_out)
155154
is_factory_method = f.category_override == 'factory' or (has_tensor_return and not has_tensor_input_arg)
156155

157156
# HACK: preserve old codegen behavior - the old codegen set the `is_factory_method`

tools/codegen/api/dispatcher.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ def name(func: FunctionSchema) -> str:
6868

6969
def arguments(func: FunctionSchema) -> Tuple[DispatcherArgument, ...]:
7070
if local.use_c10_dispatcher().dispatcher_uses_new_style():
71-
return tuple(map(argument, itertools.chain(func.arguments.positional, func.arguments.kwarg_only, func.arguments.out)))
71+
return tuple(map(argument, itertools.chain(
72+
func.arguments.flat_positional,
73+
func.arguments.flat_kwarg_only,
74+
func.arguments.out
75+
)))
7276
else:
7377
return tuple(
7478
DispatcherArgument(type=la.type, name=la.name, argument=la.argument)

tools/codegen/api/meta.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import tools.codegen.api.dispatcher as dispatcher
55

66
from typing import Sequence
7-
import itertools
87

98
# Follows dispatcher calling convention, but:
109
# - Mutable arguments not allowed. Meta functions are always
@@ -29,4 +28,4 @@ def argument(a: Argument) -> MetaArgument:
2928

3029
def arguments(func: FunctionSchema) -> Sequence[MetaArgument]:
3130
assert not func.arguments.out
32-
return list(map(argument, itertools.chain(func.arguments.positional, func.arguments.kwarg_only)))
31+
return list(map(argument, func.arguments.flat_non_out))

tools/codegen/api/python.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import itertools
21
from dataclasses import dataclass
32
from typing import Optional, Union, Sequence, Set, List, Tuple, Dict
43

@@ -734,8 +733,8 @@ def signature(f: NativeFunction, *, method: bool = False, pyi: bool = False) ->
734733
cpp_args = cpp.group_arguments(f.func, method=method, faithful=True)
735734
args = tuple(a for a in cpp_args if isinstance(a, Argument))
736735

737-
input_arg_set = set(a.name for a in f.func.arguments.positional)
738-
kwarg_only_set = set(a.name for a in f.func.arguments.kwarg_only)
736+
input_arg_set = set(a.name for a in f.func.arguments.flat_positional)
737+
kwarg_only_set = set(a.name for a in f.func.arguments.flat_kwarg_only)
739738
out_arg_set = set(a.name for a in f.func.arguments.out)
740739

741740
input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args)))
@@ -750,8 +749,7 @@ def signature(f: NativeFunction, *, method: bool = False, pyi: bool = False) ->
750749
# to the original versions in the yaml, this recreation is a potential
751750
# source of drift between eager and JIT. Pull this logic out to a shared place.
752751

753-
has_tensor_input_arg = any(a.type.is_tensor_like()
754-
for a in itertools.chain(f.func.arguments.positional, f.func.arguments.kwarg_only))
752+
has_tensor_input_arg = any(a.type.is_tensor_like() for a in f.func.arguments.flat_non_out)
755753
if any(a.name == 'requires_grad' for a in f.func.schema_order_arguments()):
756754
raise ValueError('argument named requires_grad is reserved, should not explicitly add it in the schema')
757755

tools/codegen/gen.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -486,11 +486,15 @@ def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
486486

487487
cuda_guard = ""
488488
if is_generic_dispatch_key(self.dispatch_key) or is_cuda_dispatch_key(self.dispatch_key):
489-
self_args = (a for a in f.func.arguments.positional if a.name == "self")
489+
self_arg = [f.func.arguments.self_arg.argument] if f.func.arguments.self_arg is not None else []
490490

491491
# There is precedence for which argument we use to do
492492
# device guard. This describes the precedence order.
493-
candidate_args = itertools.chain(self_args, f.func.arguments.out, f.func.arguments.positional)
493+
candidate_args = itertools.chain(
494+
self_arg,
495+
f.func.arguments.out,
496+
f.func.arguments.flat_positional
497+
)
494498

495499
# Only tensor like arguments are eligible
496500
device_of = next((f'{a.name}' for a in candidate_args if a.type.is_tensor_like()), None)
@@ -619,8 +623,7 @@ def __call__(self, f: NativeFunction) -> Optional[str]:
619623
return None
620624

621625
assert not f.func.is_out_fn()
622-
assert len(f.func.arguments.positional) > 0
623-
assert sum(a.name == 'self' for a in f.func.arguments.positional) == 1
626+
assert f.func.arguments.self_arg is not None
624627

625628
name = cpp.name(f.func)
626629

@@ -992,7 +995,7 @@ def compute_declaration_yaml(f: NativeFunction) -> object:
992995

993996
# These sets are used to conveniently test if an argument is a
994997
# kwarg-only or out argument
995-
kwarg_only_set = set(a.name for a in f.func.arguments.kwarg_only)
998+
kwarg_only_set = set(a.name for a in f.func.arguments.flat_kwarg_only)
996999
out_arg_set = set(a.name for a in f.func.arguments.out)
9971000

9981001
sig_group = CppSignatureGroup.from_schema(f.func, method=False)

tools/codegen/model.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,11 @@ class FunctionSchema:
398398
returns: Tuple['Return', ...]
399399

400400
def schema_order_arguments(self) -> Iterator['Argument']:
401-
return itertools.chain(self.arguments.positional, self.arguments.kwarg_only, self.arguments.out)
401+
return itertools.chain(
402+
self.arguments.flat_positional,
403+
self.arguments.flat_kwarg_only,
404+
self.arguments.out
405+
)
402406

403407
@staticmethod
404408
def parse(func: str) -> 'FunctionSchema':
@@ -428,7 +432,7 @@ def __post_init__(self) -> None:
428432
# This means that all mutable returns should be aliased to a keyword argument
429433
# (except for "self", which we explicitly don't treat as an out argument because of its use in methods)
430434
# See Note [is_out_fn]
431-
out_and_self = list(self.arguments.out) + [arg for arg in self.arguments.positional if arg.name == "self"]
435+
out_and_self = list(self.arguments.out) + [arg for arg in self.arguments.flat_positional if arg.name == "self"]
432436
mutable_returns = [ret for ret in self.returns if ret.annotation is not None and ret.annotation.is_write]
433437
for ret in mutable_returns:
434438
assert any([ret.annotation == arg.annotation for arg in out_and_self]), \
@@ -899,7 +903,14 @@ class Arguments:
899903
out: Tuple[Argument, ...] # these are also kwarg-only
900904

901905
@property
902-
def positional(self) -> Sequence[Argument]:
906+
def flat_non_out(self) -> Sequence[Argument]:
907+
ret: List[Argument] = []
908+
ret.extend(self.flat_positional)
909+
ret.extend(self.flat_kwarg_only)
910+
return ret
911+
912+
@property
913+
def flat_positional(self) -> Sequence[Argument]:
903914
ret: List[Argument] = []
904915
ret.extend(self.pre_self_positional)
905916
if self.self_arg is not None:
@@ -909,7 +920,7 @@ def positional(self) -> Sequence[Argument]:
909920

910921
# NB: doesn't contain out arguments
911922
@property
912-
def kwarg_only(self) -> Sequence[Argument]:
923+
def flat_kwarg_only(self) -> Sequence[Argument]:
913924
ret: List[Argument] = []
914925
ret.extend(self.pre_tensor_options_kwarg_only)
915926
if self.tensor_options is not None:
@@ -1056,10 +1067,10 @@ def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
10561067

10571068
def __str__(self) -> str:
10581069
all_arguments: List[str] = []
1059-
all_arguments.extend(map(str, self.positional))
1060-
if self.kwarg_only or self.out:
1070+
all_arguments.extend(map(str, self.flat_positional))
1071+
if self.flat_kwarg_only or self.out:
10611072
all_arguments.append('*')
1062-
all_arguments.extend(map(str, self.kwarg_only))
1073+
all_arguments.extend(map(str, self.flat_kwarg_only))
10631074
all_arguments.extend(map(str, self.out))
10641075
return ', '.join(all_arguments)
10651076

0 commit comments

Comments
 (0)