1010import pathlib
1111import functools
1212import json
13+ from dataclasses import dataclass
1314
1415from tools .codegen .code_template import CodeTemplate
1516from tools .codegen .model import *
@@ -102,13 +103,25 @@ def parse_native_yaml(path: str) -> List[NativeFunction]:
102103def with_native_function (func : Callable [[NativeFunction ], T ]) -> Callable [[NativeFunction ], T ]:
103104 @functools .wraps (func )
104105 def wrapper (f : NativeFunction ) -> T :
105- with context (f'in { f .loc } :\n { f .func } ' ):
106- with local .parametrize (
107- use_c10_dispatcher = f .use_c10_dispatcher ,
108- ):
109- return func (f )
106+ with native_function_manager (f ):
107+ return func (f )
110108 return wrapper
111109
110+ def method_with_native_function (func : Callable [[S , NativeFunction ], T ]) -> Callable [[S , NativeFunction ], T ]:
111+ @functools .wraps (func )
112+ def wrapper (slf : S , f : NativeFunction ) -> T :
113+ with native_function_manager (f ):
114+ return func (slf , f )
115+ return wrapper
116+
117+ @contextlib .contextmanager
118+ def native_function_manager (f : NativeFunction ) -> Iterator [None ]:
119+ with context (f'in { f .loc } :\n { f .func } ' ):
120+ with local .parametrize (
121+ use_c10_dispatcher = f .use_c10_dispatcher ,
122+ ):
123+ yield
124+
112125# These two functions purposely return generators in analogy to map()
113126# so that you don't mix up when you need to list() them
114127
@@ -180,49 +193,53 @@ def cpp_string(s: str) -> str:
180193#
181194# This function is also used for a secondary purpose: the registration
182195# logic is also reused to implement per-operator registration.
183- def compute_type_method (
184- dispatch : Optional [str ], * ,
196+ @dataclass (frozen = True )
197+ class ComputeTypeMethod :
198+ dispatch : Optional [str ]
199+
185200 # TODO: Give more precise type Union[Literal[Target.DEFINITION,
186201 # Target.REGISTRATION]]; requires Literal from typing_extensions
187202 # which we don't have a dep for yet.
188- target : Target ,
203+ target : Target
204+
189205 # Selector object to determine which operators to generate
190206 # registration code for.
191207 selector : SelectiveBuilder
192- ) -> Callable [[NativeFunction ], Optional [str ]]:
193208
194- if dispatch is None :
195- assert target is Target .REGISTRATION
209+ def __post_init__ (self ) -> None :
210+ assert self .target is not Target .DECLARATION
211+ if self .dispatch is None :
212+ assert self .target is Target .REGISTRATION
196213
197- @with_native_function
198- def func ( f : NativeFunction ) -> Optional [str ]:
199- # Has to be here as mypy won't transfer asserts into closures
200- assert target is not Target .DECLARATION
214+ @method_with_native_function
215+ def __call__ ( self , f : NativeFunction ) -> Optional [str ]:
216+ # for mypy type refinement; would be fixed by TODO on target
217+ assert self . target is not Target .DECLARATION
201218
202- if dispatch is not None :
203- if dispatch not in f .dispatch :
219+ if self . dispatch is not None :
220+ if self . dispatch not in f .dispatch :
204221 return None
205222
206223 op_name = f"aten::{ f .func .name } "
207- if target is Target .REGISTRATION and not selector .is_operator_selected (op_name ):
224+ if self . target is Target .REGISTRATION and not self . selector .is_operator_selected (op_name ):
208225 return None
209226
210227 name = native .name (f .func )
211228 returns_type = native .returns_type (f .func .returns )
212229 args = native .arguments (f .func )
213230 args_str = ', ' .join (map (str , args ))
214- dispatch_to_all_backends = dispatch is not None and dispatch in KEYWORD_ALL_BACKENDS
231+ dispatch_to_all_backends = self . dispatch is not None and self . dispatch in KEYWORD_ALL_BACKENDS
215232
216- if target is Target .DEFINITION :
217- assert dispatch is not None
218- impl_name = f"at::native::{ f .dispatch [dispatch ]} "
233+ if self . target is Target .DEFINITION :
234+ assert self . dispatch is not None
235+ impl_name = f"at::native::{ f .dispatch [self . dispatch ]} "
219236
220237 args_exprs_str = ', ' .join (a .name for a in args )
221238
222239 return_kw = " return "
223240
224241 cuda_guard = ""
225- if dispatch_to_all_backends or 'CUDA' in dispatch :
242+ if dispatch_to_all_backends or 'CUDA' in self . dispatch :
226243 self_args = (a for a in f .func .arguments if a .name == "self" )
227244
228245 # There is precedence for which argument we use to do
@@ -249,7 +266,7 @@ def func(f: NativeFunction) -> Optional[str]:
249266 # works just as well.
250267 if f .device_guard and dispatch_to_all_backends and has_tensor_options :
251268 cuda_guard = cuda_guard_from_tensor_options
252- elif f .device_guard and dispatch is not None and 'CUDA' in dispatch and has_tensor_options :
269+ elif f .device_guard and self . dispatch is not None and 'CUDA' in self . dispatch and has_tensor_options :
253270 cuda_guard = f"""\
254271 globalContext().lazyInitCUDA();
255272 { cuda_guard_from_tensor_options }
@@ -269,16 +286,16 @@ def func(f: NativeFunction) -> Optional[str]:
269286}}
270287"""
271288
272- elif target is Target .REGISTRATION :
273- if dispatch is None :
289+ elif self . target is Target .REGISTRATION :
290+ if self . dispatch is None :
274291 return f'm.def({ cpp_string (str (f .func ))} );\n '
275292 elif f .manual_kernel_registration :
276293 return None
277294 else :
278295 if dispatch_to_all_backends :
279296 type_name = f'TypeDefault::{ name } '
280297 else :
281- type_name = f'{ dispatch } Type::{ name } '
298+ type_name = f'{ self . dispatch } Type::{ name } '
282299
283300 dispatcher_sig = DispatcherSignature .from_schema (f .func )
284301
@@ -302,21 +319,22 @@ def func(f: NativeFunction) -> Optional[str]:
302319 # in a TORCH_LIBRARY_FRAGMENT that does not have an ambient backend. So
303320 # the torch::dispatch specification here is important! See
304321 # Note [Redundancy in registration code is OK] for how we handle redundant info.
305- if dispatch is not None :
306- payload = f"torch::dispatch(DispatchKey::{ dispatch } ,\n { payload } )\n "
322+ if self . dispatch is not None :
323+ payload = f"torch::dispatch(DispatchKey::{ self . dispatch } ,\n { payload } )\n "
307324
308325 return f'm.impl("{ f .func .name } ",\n { payload } );\n '
309326 else :
310- assert_never (target )
311-
312- return func
327+ assert_never (self .target )
313328
314329# Generates Function.cpp and Function.h. These files provide the
315330# functional public C++ API, and the scaffolding to call into
316331# the dispatcher from these functions. See also compute_tensor_method.
317- def compute_function (* , target : Target ) -> Callable [[NativeFunction ], Optional [str ]]:
318- @with_native_function
319- def go (f : NativeFunction ) -> Optional [str ]:
332+ @dataclass (frozen = True )
333+ class ComputeFunction :
334+ target : Target
335+
336+ @method_with_native_function
337+ def __call__ (self , f : NativeFunction ) -> Optional [str ]:
320338 if f .manual_kernel_registration :
321339 return None
322340 if Variant .function not in f .variants :
@@ -326,13 +344,13 @@ def go(f: NativeFunction) -> Optional[str]:
326344
327345 sig_group = CppSignatureGroup .from_schema (f .func , method = False )
328346
329- if target is Target .DECLARATION :
347+ if self . target is Target .DECLARATION :
330348 result = f"CAFFE2_API { sig_group .signature .decl ()} ;\n "
331349 if sig_group .faithful_signature is not None :
332350 result += f"CAFFE2_API { sig_group .faithful_signature .decl ()} ;\n "
333351 return result
334352
335- assert target is Target .DEFINITION
353+ assert self . target is Target .DEFINITION
336354
337355 def generate_defn (sig : CppSignature ) -> str :
338356 dispatcher_sig = DispatcherSignature .from_schema (f .func )
@@ -357,14 +375,15 @@ def generate_defn(sig: CppSignature) -> str:
357375
358376 return result
359377
360- return go
361-
362378# Generates TensorBody.h (sic) and TensorMethods.cpp. These files provide the
363379# object-oriented (method-based) public C++ API, and the scaffolding to call into
364380# the dispatcher from these functions. See also compute_function.
365- def compute_tensor_method (* , target : Target ) -> Callable [[NativeFunction ], Optional [str ]]:
366- @with_native_function
367- def go (f : NativeFunction ) -> Optional [str ]:
381+ @dataclass (frozen = True )
382+ class ComputeTensorMethod :
383+ target : Target
384+
385+ @method_with_native_function
386+ def __call__ (self , f : NativeFunction ) -> Optional [str ]:
368387 if Variant .method not in f .variants :
369388 return None
370389
@@ -376,13 +395,13 @@ def go(f: NativeFunction) -> Optional[str]:
376395
377396 sig_group = CppSignatureGroup .from_schema (f .func , method = True )
378397
379- if target is Target .DECLARATION :
398+ if self . target is Target .DECLARATION :
380399 result = f"{ sig_group .signature .decl ()} const;\n "
381400 if sig_group .faithful_signature is not None :
382401 result += f"{ sig_group .faithful_signature .decl ()} const;\n "
383402 return result
384403
385- assert target is Target .DEFINITION
404+ assert self . target is Target .DEFINITION
386405
387406 def generate_defn (sig : CppSignature ) -> str :
388407 dispatcher_sig = DispatcherSignature .from_schema (f .func )
@@ -406,8 +425,6 @@ def generate_defn(sig: CppSignature) -> str:
406425
407426 return result
408427
409- return go
410-
411428# Generates ATenOpList.cpp, a runtime accessible list of all aten
412429# operators.
413430# TODO: This was historically used to help some JIT interop code
@@ -442,9 +459,12 @@ def compute_native_function_declaration(f: NativeFunction) -> List[str]:
442459# Generates BackendSelectRegister.cpp, a series of kernels which provide
443460# specialized computation of dispatch key for operator signatures which cannot
444461# be easily done automatically using templating.
445- def compute_backend_select (* , target : Target ) -> Callable [[NativeFunction ], Optional [str ]]:
446- @with_native_function
447- def go (f : NativeFunction ) -> Optional [str ]:
462+ @dataclass (frozen = True )
463+ class ComputeBackendSelect :
464+ target : Target
465+
466+ @method_with_native_function
467+ def __call__ (self , f : NativeFunction ) -> Optional [str ]:
448468 if str (f .func .name .name ).endswith ('_like' ) or str (f .func .name .name ).startswith ('new_' ):
449469 return None
450470
@@ -471,7 +491,7 @@ def go(f: NativeFunction) -> Optional[str]:
471491 dispatcher_exprs = native_sig .dispatcher_exprs ()
472492 dispatch_key = "options.computeDispatchKey()"
473493
474- if target is Target .DEFINITION :
494+ if self . target is Target .DEFINITION :
475495 # I don't think there's actually a good reason to generate
476496 # these two cases differently
477497 # The first case could probably be improved though- it calls dispatchTypeId(),
@@ -494,7 +514,7 @@ def go(f: NativeFunction) -> Optional[str]:
494514 return op.callWithDispatchKey(_dk, { ', ' .join (a .expr for a in dispatcher_exprs )} );
495515}}
496516"""
497- elif target is Target .REGISTRATION :
517+ elif self . target is Target .REGISTRATION :
498518 if local .use_c10_dispatcher () is UseC10Dispatcher .full :
499519 return f"""m.impl("aten::{ f .func .name } ", TORCH_FN({ name } ));"""
500520 elif local .use_c10_dispatcher () is UseC10Dispatcher .hacky_wrapper_for_legacy_signatures :
@@ -504,11 +524,10 @@ def go(f: NativeFunction) -> Optional[str]:
504524 else :
505525 assert local .use_c10_dispatcher () is UseC10Dispatcher .with_codegenerated_unboxing_wrapper
506526 return f"""m.impl_UNBOXED("aten::{ f .func .name } ", { name } );"""
507- elif target is Target .DECLARATION :
527+ elif self . target is Target .DECLARATION :
508528 raise AssertionError ()
509529 else :
510- assert_never (target )
511- return go
530+ assert_never (self .target )
512531
513532# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
514533#
@@ -993,12 +1012,11 @@ def make_file_manager(install_dir: str) -> FileManager:
9931012 '' ,
9941013 'Backend' : dispatch ,
9951014 'type_derived_method_definitions' : list (mapMaybe (
996- compute_type_method (dispatch , target = Target .DEFINITION , selector = selector ),
1015+ ComputeTypeMethod (dispatch , Target .DEFINITION , selector ),
9971016 native_functions
9981017 )),
9991018 'function_registrations' : list (mapMaybe (
1000- compute_type_method (
1001- dispatch , target = Target .REGISTRATION , selector = selector ),
1019+ ComputeTypeMethod (dispatch , Target .REGISTRATION , selector ),
10021020 native_functions
10031021 )),
10041022 })
@@ -1012,35 +1030,35 @@ def make_file_manager(install_dir: str) -> FileManager:
10121030 cpu_fm .write ('TypeDefault.cpp' , lambda : {
10131031 'type_method_definitions' :
10141032 list (mapMaybe (
1015- compute_type_method ('Math' , target = Target .DEFINITION , selector = selector ),
1033+ ComputeTypeMethod ('Math' , Target .DEFINITION , selector ),
10161034 native_functions )) +
10171035 list (mapMaybe (
1018- compute_type_method ('DefaultBackend' , target = Target .DEFINITION , selector = selector ),
1036+ ComputeTypeMethod ('DefaultBackend' , Target .DEFINITION , selector ),
10191037 native_functions )),
10201038
10211039 'function_registrations' : list (mapMaybe (
1022- compute_type_method (None , target = Target .REGISTRATION , selector = schema_selector ),
1040+ ComputeTypeMethod (None , Target .REGISTRATION , schema_selector ),
10231041 native_functions )),
10241042
10251043 'math_function_registrations' : list (mapMaybe (
1026- compute_type_method ('Math' , target = Target .REGISTRATION , selector = selector ),
1044+ ComputeTypeMethod ('Math' , Target .REGISTRATION , selector ),
10271045 native_functions )),
10281046
10291047 'default_backend_function_registrations' : list (mapMaybe (
1030- compute_type_method ('DefaultBackend' , target = Target .REGISTRATION , selector = selector ),
1048+ ComputeTypeMethod ('DefaultBackend' , Target .REGISTRATION , selector ),
10311049 native_functions )),
10321050 })
10331051 cpu_fm .write ('Functions.h' , lambda : {
1034- 'function_declarations' : list (mapMaybe (compute_function ( target = Target .DECLARATION ), native_functions )),
1052+ 'function_declarations' : list (mapMaybe (ComputeFunction ( Target .DECLARATION ), native_functions )),
10351053 })
10361054 cpu_fm .write ('Functions.cpp' , lambda : {
1037- 'function_definitions' : list (mapMaybe (compute_function ( target = Target .DEFINITION ), native_functions )),
1055+ 'function_definitions' : list (mapMaybe (ComputeFunction ( Target .DEFINITION ), native_functions )),
10381056 })
10391057 core_fm .write ('TensorBody.h' , lambda : {
1040- 'tensor_method_declarations' : list (mapMaybe (compute_tensor_method ( target = Target .DECLARATION ), native_functions )),
1058+ 'tensor_method_declarations' : list (mapMaybe (ComputeTensorMethod ( Target .DECLARATION ), native_functions )),
10411059 })
10421060 core_fm .write ('TensorMethods.cpp' , lambda : {
1043- 'tensor_method_definitions' : list (mapMaybe (compute_tensor_method ( target = Target .DEFINITION ), native_functions )),
1061+ 'tensor_method_definitions' : list (mapMaybe (ComputeTensorMethod ( Target .DEFINITION ), native_functions )),
10441062 })
10451063 core_fm .write ('ATenOpList.cpp' , lambda : {
10461064 'aten_ops' : list (mapMaybe (compute_aten_op , native_functions )),
@@ -1050,9 +1068,9 @@ def make_file_manager(install_dir: str) -> FileManager:
10501068 })
10511069 cpu_fm .write ('BackendSelectRegister.cpp' , lambda : {
10521070 'backend_select_method_definitions' :
1053- list (mapMaybe (compute_backend_select ( target = Target .DEFINITION ), native_functions )),
1071+ list (mapMaybe (ComputeBackendSelect ( Target .DEFINITION ), native_functions )),
10541072 'backend_select_function_registrations' :
1055- list (mapMaybe (compute_backend_select ( target = Target .REGISTRATION ), native_functions )),
1073+ list (mapMaybe (ComputeBackendSelect ( Target .REGISTRATION ), native_functions )),
10561074 })
10571075
10581076 cpu_fm .write ('Declarations.yaml' , lambda : format_yaml ([compute_declaration_yaml (f ) for f in native_functions ]))
0 commit comments