Skip to content

Commit 0c64f9f

Browse files
ezyangfacebook-github-bot
authored andcommitted
Convert from higher order functions to classes in tools.codegen.gen (#47008)
Summary: Pull Request resolved: #47008 bhosmer has been complaining about how it is difficult to distinguish between local variables and closed over variables in the higher order functions. Well, closures and objects do basically the same thing, so just convert all these HOFs into objects. The decoder ring: - Higher order function => Constructor for object - Access to closed over variable => Access to member variable on object - with_native_function => method_with_native_function (because it's hard writing decorators that work for both functions and methods) I didn't even have to change indentation (much). When there is no need for closed over variables (a few functions), I kept them as plain old functions, no need for an object with no members. While I was at it, I also deleted the kwargs, since the types are enough to prevent mistakes. Signed-off-by: Edward Z. Yang <[email protected]> Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D24600805 Pulled By: ezyang fbshipit-source-id: 7e3ce8cb2446e3788f934ddcc17f7da6e9299511
1 parent d478605 commit 0c64f9f

File tree

1 file changed

+86
-68
lines changed

1 file changed

+86
-68
lines changed

tools/codegen/gen.py

Lines changed: 86 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pathlib
1111
import functools
1212
import json
13+
from dataclasses import dataclass
1314

1415
from tools.codegen.code_template import CodeTemplate
1516
from tools.codegen.model import *
@@ -102,13 +103,25 @@ def parse_native_yaml(path: str) -> List[NativeFunction]:
102103
def with_native_function(func: Callable[[NativeFunction], T]) -> Callable[[NativeFunction], T]:
103104
@functools.wraps(func)
104105
def wrapper(f: NativeFunction) -> T:
105-
with context(f'in {f.loc}:\n {f.func}'):
106-
with local.parametrize(
107-
use_c10_dispatcher=f.use_c10_dispatcher,
108-
):
109-
return func(f)
106+
with native_function_manager(f):
107+
return func(f)
110108
return wrapper
111109

110+
def method_with_native_function(func: Callable[[S, NativeFunction], T]) -> Callable[[S, NativeFunction], T]:
111+
@functools.wraps(func)
112+
def wrapper(slf: S, f: NativeFunction) -> T:
113+
with native_function_manager(f):
114+
return func(slf, f)
115+
return wrapper
116+
117+
@contextlib.contextmanager
118+
def native_function_manager(f: NativeFunction) -> Iterator[None]:
119+
with context(f'in {f.loc}:\n {f.func}'):
120+
with local.parametrize(
121+
use_c10_dispatcher=f.use_c10_dispatcher,
122+
):
123+
yield
124+
112125
# These two functions purposely return generators in analogy to map()
113126
# so that you don't mix up when you need to list() them
114127

@@ -180,49 +193,53 @@ def cpp_string(s: str) -> str:
180193
#
181194
# This function is also used for a secondary purpose: the registration
182195
# logic is also reused to implement per-operator registration.
183-
def compute_type_method(
184-
dispatch: Optional[str], *,
196+
@dataclass(frozen=True)
197+
class ComputeTypeMethod:
198+
dispatch: Optional[str]
199+
185200
# TODO: Give more precise type Union[Literal[Target.DEFINITION,
186201
# Target.REGISTRATION]]; requires Literal from typing_extensions
187202
# which we don't have a dep for yet.
188-
target: Target,
203+
target: Target
204+
189205
# Selector object to determine which operators to generate
190206
# registration code for.
191207
selector: SelectiveBuilder
192-
) -> Callable[[NativeFunction], Optional[str]]:
193208

194-
if dispatch is None:
195-
assert target is Target.REGISTRATION
209+
def __post_init__(self) -> None:
210+
assert self.target is not Target.DECLARATION
211+
if self.dispatch is None:
212+
assert self.target is Target.REGISTRATION
196213

197-
@with_native_function
198-
def func(f: NativeFunction) -> Optional[str]:
199-
# Has to be here as mypy won't transfer asserts into closures
200-
assert target is not Target.DECLARATION
214+
@method_with_native_function
215+
def __call__(self, f: NativeFunction) -> Optional[str]:
216+
# for mypy type refinement; would be fixed by TODO on target
217+
assert self.target is not Target.DECLARATION
201218

202-
if dispatch is not None:
203-
if dispatch not in f.dispatch:
219+
if self.dispatch is not None:
220+
if self.dispatch not in f.dispatch:
204221
return None
205222

206223
op_name = f"aten::{f.func.name}"
207-
if target is Target.REGISTRATION and not selector.is_operator_selected(op_name):
224+
if self.target is Target.REGISTRATION and not self.selector.is_operator_selected(op_name):
208225
return None
209226

210227
name = native.name(f.func)
211228
returns_type = native.returns_type(f.func.returns)
212229
args = native.arguments(f.func)
213230
args_str = ', '.join(map(str, args))
214-
dispatch_to_all_backends = dispatch is not None and dispatch in KEYWORD_ALL_BACKENDS
231+
dispatch_to_all_backends = self.dispatch is not None and self.dispatch in KEYWORD_ALL_BACKENDS
215232

216-
if target is Target.DEFINITION:
217-
assert dispatch is not None
218-
impl_name = f"at::native::{f.dispatch[dispatch]}"
233+
if self.target is Target.DEFINITION:
234+
assert self.dispatch is not None
235+
impl_name = f"at::native::{f.dispatch[self.dispatch]}"
219236

220237
args_exprs_str = ', '.join(a.name for a in args)
221238

222239
return_kw = " return "
223240

224241
cuda_guard = ""
225-
if dispatch_to_all_backends or 'CUDA' in dispatch:
242+
if dispatch_to_all_backends or 'CUDA' in self.dispatch:
226243
self_args = (a for a in f.func.arguments if a.name == "self")
227244

228245
# There is precedence for which argument we use to do
@@ -249,7 +266,7 @@ def func(f: NativeFunction) -> Optional[str]:
249266
# works just as well.
250267
if f.device_guard and dispatch_to_all_backends and has_tensor_options:
251268
cuda_guard = cuda_guard_from_tensor_options
252-
elif f.device_guard and dispatch is not None and 'CUDA' in dispatch and has_tensor_options:
269+
elif f.device_guard and self.dispatch is not None and 'CUDA' in self.dispatch and has_tensor_options:
253270
cuda_guard = f"""\
254271
globalContext().lazyInitCUDA();
255272
{cuda_guard_from_tensor_options}
@@ -269,16 +286,16 @@ def func(f: NativeFunction) -> Optional[str]:
269286
}}
270287
"""
271288

272-
elif target is Target.REGISTRATION:
273-
if dispatch is None:
289+
elif self.target is Target.REGISTRATION:
290+
if self.dispatch is None:
274291
return f'm.def({cpp_string(str(f.func))});\n'
275292
elif f.manual_kernel_registration:
276293
return None
277294
else:
278295
if dispatch_to_all_backends:
279296
type_name = f'TypeDefault::{name}'
280297
else:
281-
type_name = f'{dispatch}Type::{name}'
298+
type_name = f'{self.dispatch}Type::{name}'
282299

283300
dispatcher_sig = DispatcherSignature.from_schema(f.func)
284301

@@ -302,21 +319,22 @@ def func(f: NativeFunction) -> Optional[str]:
302319
# in a TORCH_LIBRARY_FRAGMENT that does not have an ambient backend. So
303320
# the torch::dispatch specification here is important! See
304321
# Note [Redundancy in registration code is OK] for how we handle redundant info.
305-
if dispatch is not None:
306-
payload = f"torch::dispatch(DispatchKey::{dispatch},\n{payload})\n"
322+
if self.dispatch is not None:
323+
payload = f"torch::dispatch(DispatchKey::{self.dispatch},\n{payload})\n"
307324

308325
return f'm.impl("{f.func.name}",\n{payload});\n'
309326
else:
310-
assert_never(target)
311-
312-
return func
327+
assert_never(self.target)
313328

314329
# Generates Function.cpp and Function.h. These files provide the
315330
# functional public C++ API, and the scaffolding to call into
316331
# the dispatcher from these functions. See also compute_tensor_method.
317-
def compute_function(*, target: Target) -> Callable[[NativeFunction], Optional[str]]:
318-
@with_native_function
319-
def go(f: NativeFunction) -> Optional[str]:
332+
@dataclass(frozen=True)
333+
class ComputeFunction:
334+
target: Target
335+
336+
@method_with_native_function
337+
def __call__(self, f: NativeFunction) -> Optional[str]:
320338
if f.manual_kernel_registration:
321339
return None
322340
if Variant.function not in f.variants:
@@ -326,13 +344,13 @@ def go(f: NativeFunction) -> Optional[str]:
326344

327345
sig_group = CppSignatureGroup.from_schema(f.func, method=False)
328346

329-
if target is Target.DECLARATION:
347+
if self.target is Target.DECLARATION:
330348
result = f"CAFFE2_API {sig_group.signature.decl()};\n"
331349
if sig_group.faithful_signature is not None:
332350
result += f"CAFFE2_API {sig_group.faithful_signature.decl()};\n"
333351
return result
334352

335-
assert target is Target.DEFINITION
353+
assert self.target is Target.DEFINITION
336354

337355
def generate_defn(sig: CppSignature) -> str:
338356
dispatcher_sig = DispatcherSignature.from_schema(f.func)
@@ -357,14 +375,15 @@ def generate_defn(sig: CppSignature) -> str:
357375

358376
return result
359377

360-
return go
361-
362378
# Generates TensorBody.h (sic) and TensorMethods.cpp. These files provide the
363379
# object-oriented (method-based) public C++ API, and the scaffolding to call into
364380
# the dispatcher from these functions. See also compute_function.
365-
def compute_tensor_method(*, target: Target) -> Callable[[NativeFunction], Optional[str]]:
366-
@with_native_function
367-
def go(f: NativeFunction) -> Optional[str]:
381+
@dataclass(frozen=True)
382+
class ComputeTensorMethod:
383+
target: Target
384+
385+
@method_with_native_function
386+
def __call__(self, f: NativeFunction) -> Optional[str]:
368387
if Variant.method not in f.variants:
369388
return None
370389

@@ -376,13 +395,13 @@ def go(f: NativeFunction) -> Optional[str]:
376395

377396
sig_group = CppSignatureGroup.from_schema(f.func, method=True)
378397

379-
if target is Target.DECLARATION:
398+
if self.target is Target.DECLARATION:
380399
result = f"{sig_group.signature.decl()} const;\n"
381400
if sig_group.faithful_signature is not None:
382401
result += f"{sig_group.faithful_signature.decl()} const;\n"
383402
return result
384403

385-
assert target is Target.DEFINITION
404+
assert self.target is Target.DEFINITION
386405

387406
def generate_defn(sig: CppSignature) -> str:
388407
dispatcher_sig = DispatcherSignature.from_schema(f.func)
@@ -406,8 +425,6 @@ def generate_defn(sig: CppSignature) -> str:
406425

407426
return result
408427

409-
return go
410-
411428
# Generates ATenOpList.cpp, a runtime accessible list of all aten
412429
# operators.
413430
# TODO: This was historically used to help some JIT interop code
@@ -442,9 +459,12 @@ def compute_native_function_declaration(f: NativeFunction) -> List[str]:
442459
# Generates BackendSelectRegister.cpp, a series of kernels which provide
443460
# specialized computation of dispatch key for operator signatures which cannot
444461
# be easily done automatically using templating.
445-
def compute_backend_select(*, target: Target) -> Callable[[NativeFunction], Optional[str]]:
446-
@with_native_function
447-
def go(f: NativeFunction) -> Optional[str]:
462+
@dataclass(frozen=True)
463+
class ComputeBackendSelect:
464+
target: Target
465+
466+
@method_with_native_function
467+
def __call__(self, f: NativeFunction) -> Optional[str]:
448468
if str(f.func.name.name).endswith('_like') or str(f.func.name.name).startswith('new_'):
449469
return None
450470

@@ -471,7 +491,7 @@ def go(f: NativeFunction) -> Optional[str]:
471491
dispatcher_exprs = native_sig.dispatcher_exprs()
472492
dispatch_key = "options.computeDispatchKey()"
473493

474-
if target is Target.DEFINITION:
494+
if self.target is Target.DEFINITION:
475495
# I don't think there's actually a good reason to generate
476496
# these two cases differently
477497
# The first case could probably be improved though- it calls dispatchTypeId(),
@@ -494,7 +514,7 @@ def go(f: NativeFunction) -> Optional[str]:
494514
return op.callWithDispatchKey(_dk, {', '.join(a.expr for a in dispatcher_exprs)});
495515
}}
496516
"""
497-
elif target is Target.REGISTRATION:
517+
elif self.target is Target.REGISTRATION:
498518
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
499519
return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
500520
elif local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures:
@@ -504,11 +524,10 @@ def go(f: NativeFunction) -> Optional[str]:
504524
else:
505525
assert local.use_c10_dispatcher() is UseC10Dispatcher.with_codegenerated_unboxing_wrapper
506526
return f"""m.impl_UNBOXED("aten::{f.func.name}", {name});"""
507-
elif target is Target.DECLARATION:
527+
elif self.target is Target.DECLARATION:
508528
raise AssertionError()
509529
else:
510-
assert_never(target)
511-
return go
530+
assert_never(self.target)
512531

513532
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
514533
#
@@ -993,12 +1012,11 @@ def make_file_manager(install_dir: str) -> FileManager:
9931012
'',
9941013
'Backend': dispatch,
9951014
'type_derived_method_definitions': list(mapMaybe(
996-
compute_type_method(dispatch, target=Target.DEFINITION, selector=selector),
1015+
ComputeTypeMethod(dispatch, Target.DEFINITION, selector),
9971016
native_functions
9981017
)),
9991018
'function_registrations': list(mapMaybe(
1000-
compute_type_method(
1001-
dispatch, target=Target.REGISTRATION, selector=selector),
1019+
ComputeTypeMethod(dispatch, Target.REGISTRATION, selector),
10021020
native_functions
10031021
)),
10041022
})
@@ -1012,35 +1030,35 @@ def make_file_manager(install_dir: str) -> FileManager:
10121030
cpu_fm.write('TypeDefault.cpp', lambda: {
10131031
'type_method_definitions':
10141032
list(mapMaybe(
1015-
compute_type_method('Math', target=Target.DEFINITION, selector=selector),
1033+
ComputeTypeMethod('Math', Target.DEFINITION, selector),
10161034
native_functions)) +
10171035
list(mapMaybe(
1018-
compute_type_method('DefaultBackend', target=Target.DEFINITION, selector=selector),
1036+
ComputeTypeMethod('DefaultBackend', Target.DEFINITION, selector),
10191037
native_functions)),
10201038

10211039
'function_registrations': list(mapMaybe(
1022-
compute_type_method(None, target=Target.REGISTRATION, selector=schema_selector),
1040+
ComputeTypeMethod(None, Target.REGISTRATION, schema_selector),
10231041
native_functions)),
10241042

10251043
'math_function_registrations': list(mapMaybe(
1026-
compute_type_method('Math', target=Target.REGISTRATION, selector=selector),
1044+
ComputeTypeMethod('Math', Target.REGISTRATION, selector),
10271045
native_functions)),
10281046

10291047
'default_backend_function_registrations': list(mapMaybe(
1030-
compute_type_method('DefaultBackend', target=Target.REGISTRATION, selector=selector),
1048+
ComputeTypeMethod('DefaultBackend', Target.REGISTRATION, selector),
10311049
native_functions)),
10321050
})
10331051
cpu_fm.write('Functions.h', lambda: {
1034-
'function_declarations': list(mapMaybe(compute_function(target=Target.DECLARATION), native_functions)),
1052+
'function_declarations': list(mapMaybe(ComputeFunction(Target.DECLARATION), native_functions)),
10351053
})
10361054
cpu_fm.write('Functions.cpp', lambda: {
1037-
'function_definitions': list(mapMaybe(compute_function(target=Target.DEFINITION), native_functions)),
1055+
'function_definitions': list(mapMaybe(ComputeFunction(Target.DEFINITION), native_functions)),
10381056
})
10391057
core_fm.write('TensorBody.h', lambda: {
1040-
'tensor_method_declarations': list(mapMaybe(compute_tensor_method(target=Target.DECLARATION), native_functions)),
1058+
'tensor_method_declarations': list(mapMaybe(ComputeTensorMethod(Target.DECLARATION), native_functions)),
10411059
})
10421060
core_fm.write('TensorMethods.cpp', lambda: {
1043-
'tensor_method_definitions': list(mapMaybe(compute_tensor_method(target=Target.DEFINITION), native_functions)),
1061+
'tensor_method_definitions': list(mapMaybe(ComputeTensorMethod(Target.DEFINITION), native_functions)),
10441062
})
10451063
core_fm.write('ATenOpList.cpp', lambda: {
10461064
'aten_ops': list(mapMaybe(compute_aten_op, native_functions)),
@@ -1050,9 +1068,9 @@ def make_file_manager(install_dir: str) -> FileManager:
10501068
})
10511069
cpu_fm.write('BackendSelectRegister.cpp', lambda: {
10521070
'backend_select_method_definitions':
1053-
list(mapMaybe(compute_backend_select(target=Target.DEFINITION), native_functions)),
1071+
list(mapMaybe(ComputeBackendSelect(Target.DEFINITION), native_functions)),
10541072
'backend_select_function_registrations':
1055-
list(mapMaybe(compute_backend_select(target=Target.REGISTRATION), native_functions)),
1073+
list(mapMaybe(ComputeBackendSelect(Target.REGISTRATION), native_functions)),
10561074
})
10571075

10581076
cpu_fm.write('Declarations.yaml', lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]))

0 commit comments

Comments
 (0)