Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 6cbd132

Browse files
committed
refactoring : extract GeneratorBase
1 parent 053f9fa commit 6cbd132

File tree

6 files changed

+248
-357
lines changed

6 files changed

+248
-357
lines changed

scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.mxnet
1919

2020
import org.scalatest.{BeforeAndAfterAll, FunSuite}
2121

22+
2223
class SymbolSuite extends FunSuite with BeforeAndAfterAll {
2324

2425
test("symbol compose") {

scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala

Lines changed: 63 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717

1818
package org.apache.mxnet
1919

20-
import org.apache.mxnet.init.Base._
21-
import org.apache.mxnet.utils.CToScalaUtils
2220
import java.io._
2321
import java.security.MessageDigest
2422

@@ -29,13 +27,11 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
2927
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
3028
* The code will be executed during Macros stage and file live in Core stage
3129
*/
32-
private[mxnet] object APIDocGenerator{
33-
case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean)
34-
case class absClassFunction(name : String, desc : String,
35-
listOfArgs: List[absClassArg], returnType : String)
30+
private[mxnet] object APIDocGenerator extends GeneratorBase {
31+
type absClassArg = Arg
32+
type absClassFunction = Func
3633

37-
38-
def main(args: Array[String]) : Unit = {
34+
def main(args: Array[String]): Unit = {
3935
val FILE_PATH = args(0)
4036
val hashCollector = ListBuffer[String]()
4137
hashCollector += absClassGen(FILE_PATH, true)
@@ -47,68 +43,70 @@ private[mxnet] object APIDocGenerator{
4743
val finalHash = hashCollector.mkString("\n")
4844
}
4945

50-
def MD5Generator(input : String) : String = {
46+
def MD5Generator(input: String): String = {
5147
val md = MessageDigest.getInstance("MD5")
5248
md.update(input.getBytes("UTF-8"))
5349
val digest = md.digest()
5450
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
5551
}
5652

57-
def absRndClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
58-
typeSafeClassGen(
59-
getSymbolNDArrayMethods(isSymbol)
60-
.filter(f => f.name.startsWith("_random") || f.name.startsWith("_sample"))
61-
.map(f => f.copy(name = f.name.stripPrefix("_"))),
53+
def absRndClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
54+
val funcs = getSymbolNDArrayMethods(isSymbol)
55+
.filter(f => f.name.startsWith("_sample_") || f.name.startsWith("_random_"))
56+
.map(f => f.copy(name = f.name.stripPrefix("_")))
57+
val body = funcs.map(func => {
58+
val scalaDoc = generateAPIDocFromBackend(func)
59+
val decl = generateRandomAPISignature(func, isSymbol)
60+
s"$scalaDoc\n$decl"
61+
})
62+
writeFile(
6263
FILE_PATH,
6364
if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase",
64-
isSymbol
65-
)
65+
body)
6666
}
6767

68-
def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
68+
def absClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
6969
val notGenerated = Set("Custom")
70-
typeSafeClassGen(
71-
getSymbolNDArrayMethods(isSymbol)
72-
.filterNot(_.name.startsWith("_"))
73-
.filterNot(ele => notGenerated.contains(ele.name)),
70+
val funcs = getSymbolNDArrayMethods(isSymbol)
71+
.filterNot(_.name.startsWith("_"))
72+
.filterNot(ele => notGenerated.contains(ele.name))
73+
val body = funcs.map(func => {
74+
val scalaDoc = generateAPIDocFromBackend(func)
75+
val decl = generateAPISignature(func, isSymbol)
76+
s"$scalaDoc\n$decl"
77+
})
78+
writeFile(
7479
FILE_PATH,
7580
if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
76-
isSymbol
77-
)
81+
body)
7882
}
7983

80-
def typeSafeClassGen(absClassFunctions: Seq[absClassFunction], FILE_PATH: String,
81-
packageName: String, isSymbol: Boolean): String = {
82-
val absFuncs = absClassFunctions
83-
.map(absClassFunction => {
84-
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
85-
val defBody = generateAPISignature(absClassFunction, isSymbol)
86-
s"$scalaDoc\n$defBody"
87-
})
88-
writeFile(FILE_PATH, packageName, absFuncs)
89-
}
90-
91-
def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
92-
// scalastyle:off
84+
def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
9385
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
9486
val absFuncs = absClassFunctions
9587
.filterNot(_.name.startsWith("_"))
9688
.map(absClassFunction => {
97-
val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
98-
if (isSymbol) {
99-
val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol"
100-
s"$scalaDoc\n$defBody"
101-
} else {
102-
val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
103-
val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
104-
s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
105-
}
106-
})
89+
val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
90+
if (isSymbol) {
91+
val defBody =
92+
s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)" +
93+
s"(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): " +
94+
s"org.apache.mxnet.Symbol"
95+
s"$scalaDoc\n$defBody"
96+
} else {
97+
val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)" +
98+
s"(args: Any*): " +
99+
s"org.apache.mxnet.NDArrayFuncReturn"
100+
val defBody = s"def ${absClassFunction.name}(args: Any*): " +
101+
s"org.apache.mxnet.NDArrayFuncReturn"
102+
s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
103+
}
104+
})
107105
val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
108106
writeFile(FILE_PATH, packageName, absFuncs)
109107
}
110108

111-
def writeFile(FILE_PATH: String, packageName: String, absFuncs: Seq[String]): String = {
109+
def writeFile(FILE_PATH: String, packageName: String, body: Seq[String]): String = {
112110
val apacheLicence =
113111
"""/*
114112
|* Licensed to the Apache Software Foundation (ASF) under one or more
@@ -137,7 +135,7 @@ private[mxnet] object APIDocGenerator{
137135
|$packageDef
138136
|$imports
139137
|$absClassDef {
140-
|${absFuncs.mkString("\n")}
138+
|${body.mkString("\n")}
141139
|}""".stripMargin
142140
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
143141
pw.write(finalStr)
@@ -146,20 +144,15 @@ private[mxnet] object APIDocGenerator{
146144
}
147145

148146
// Generate ScalaDoc type
149-
def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = {
147+
def generateAPIDocFromBackend(func: absClassFunction, withParam: Boolean = true): String = {
150148
val desc = ArrayBuffer[String]()
151149
desc += " * <pre>"
152-
func.desc.split("\n").foreach({ currStr =>
150+
func.desc.split("\n").foreach({ currStr =>
153151
desc += s" * $currStr"
154152
})
155153
desc += " * </pre>"
156154
val params = func.listOfArgs.map({ absClassArg =>
157-
val currArgName = absClassArg.argName match {
158-
case "var" => "vari"
159-
case "type" => "typeOf"
160-
case _ => absClassArg.argName
161-
}
162-
s" * @param $currArgName\t\t${absClassArg.argDesc}"
155+
s" * @param ${absClassArg.safeArgName}\t\t${absClassArg.argDesc}"
163156
})
164157
val returnType = s" * @return ${func.returnType}"
165158
if (withParam) {
@@ -169,64 +162,31 @@ private[mxnet] object APIDocGenerator{
169162
}
170163
}
171164

172-
def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
173-
var argDef = ListBuffer[String]()
174-
func.listOfArgs.foreach(absClassArg => {
175-
val currArgName = absClassArg.argName match {
176-
case "var" => "vari"
177-
case "type" => "typeOf"
178-
case _ => absClassArg.argName
179-
}
180-
if (absClassArg.isOptional) {
181-
argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
182-
}
183-
else {
184-
argDef += s"$currArgName : ${absClassArg.argType}"
185-
}
186-
})
187-
var returnType = func.returnType
165+
def generateRandomAPISignature(func: absClassFunction, isSymbol: Boolean): String = {
166+
generateAPISignature(func, isSymbol)
167+
}
168+
169+
def generateAPISignature(func: absClassFunction, isSymbol: Boolean): String = {
170+
val argDef = ListBuffer[String]()
171+
172+
argDef ++= buildArgDefs(func)
173+
188174
if (isSymbol) {
189175
argDef += "name : String = null"
190176
argDef += "attr : Map[String, String] = null"
191177
} else {
192178
argDef += "out : Option[NDArray] = None"
193-
returnType = "org.apache.mxnet.NDArrayFuncReturn"
194179
}
180+
181+
val returnType = func.returnType
182+
195183
val experimentalTag = "@Experimental"
196184
s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType"
197185
}
198186

199187
// List and add all the atomic symbol functions to current module.
200-
private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = {
201-
val opNames = ListBuffer.empty[String]
202-
val returnType = if (isSymbol) "Symbol" else "NDArray"
203-
_LIB.mxListAllOpNames(opNames)
204-
// TODO: Add '_linalg_', '_sparse_', '_image_' support
205-
// TODO: Add Filter to the same location in case of refactor
206-
opNames.map(opName => {
207-
val opHandle = new RefLong
208-
_LIB.nnGetOpHandle(opName, opHandle)
209-
makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType)
210-
}).toList
188+
private def getSymbolNDArrayMethods(isSymbol: Boolean): List[absClassFunction] = {
189+
buildFunctionList(isSymbol)
211190
}
212191

213-
// Create an atomic symbol function by handle and function name.
214-
private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String)
215-
: absClassFunction = {
216-
val name = new RefString
217-
val desc = new RefString
218-
val keyVarNumArgs = new RefString
219-
val numArgs = new RefInt
220-
val argNames = ListBuffer.empty[String]
221-
val argTypes = ListBuffer.empty[String]
222-
val argDescs = ListBuffer.empty[String]
223-
224-
_LIB.mxSymbolGetAtomicSymbolInfo(
225-
handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
226-
val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
227-
val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType)
228-
absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
229-
}
230-
absClassFunction(aliasName, desc.value, argList.toList, returnType)
231-
}
232192
}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package org.apache.mxnet
2+
3+
import org.apache.mxnet.init.Base.{RefInt, RefLong, RefString, _LIB}
4+
import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
5+
6+
import scala.collection.mutable.ListBuffer
7+
import scala.reflect.macros.blackbox
8+
9+
abstract class GeneratorBase {
10+
type Handle = Long
11+
12+
case class Arg(argName: String, argType: String, argDesc: String, isOptional: Boolean) {
13+
def safeArgName: String = argName match {
14+
case "var" => "vari"
15+
case "type" => "typeOf"
16+
case _ => argName
17+
}
18+
}
19+
20+
case class Func(name: String, desc: String, listOfArgs: List[Arg], returnType: String)
21+
22+
protected def buildFunctionList(isSymbol: Boolean): List[Func] = {
23+
val opNames = ListBuffer.empty[String]
24+
_LIB.mxListAllOpNames(opNames)
25+
opNames.map(opName => {
26+
val opHandle = new RefLong
27+
_LIB.nnGetOpHandle(opName, opHandle)
28+
makeAtomicFunction(opHandle.value, opName, isSymbol)
29+
}).toList
30+
}
31+
32+
protected def makeAtomicFunction(handle: Handle, aliasName: String, isSymbol: Boolean): Func = {
33+
val name = new RefString
34+
val desc = new RefString
35+
val keyVarNumArgs = new RefString
36+
val numArgs = new RefInt
37+
val argNames = ListBuffer.empty[String]
38+
val argTypes = ListBuffer.empty[String]
39+
val argDescs = ListBuffer.empty[String]
40+
41+
_LIB.mxSymbolGetAtomicSymbolInfo(
42+
handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
43+
val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs)
44+
val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) {
45+
s"This function support variable length of positional input (${keyVarNumArgs.value})."
46+
} else {
47+
""
48+
}
49+
val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})"
50+
val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
51+
// scalastyle:off println
52+
if (System.getenv("MXNET4J_PRINT_OP_DEF") != null
53+
&& System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
54+
println("Function definition:\n" + docStr)
55+
}
56+
// scalastyle:on println
57+
val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
58+
val family = if(isSymbol) "org.apache.mxnet.Symbol" else "org.apache.mxnet.NDArray"
59+
val typeAndOption =
60+
CToScalaUtils.argumentCleaner(argName, argType, family)
61+
Arg(argName, typeAndOption._1, argDesc, typeAndOption._2)
62+
}
63+
val returnType = if(isSymbol) "org.apache.mxnet.Symbol" else "org.apache.mxnet.NDArrayFuncReturn"
64+
Func(aliasName, desc.value, argList.toList, returnType)
65+
}
66+
67+
/**
68+
* Generate class structure for all function APIs
69+
*
70+
* @param c
71+
* @param funcDef DefDef type of function definitions
72+
* @param annottees
73+
* @return
74+
*/
75+
protected def structGeneration(c: blackbox.Context)
76+
(funcDef: List[c.universe.DefDef], annottees: c.Expr[Any]*)
77+
: c.Expr[Any] = {
78+
import c.universe._
79+
val inputs = annottees.map(_.tree).toList
80+
// pattern match on the inputs
81+
val modDefs = inputs map {
82+
case ClassDef(mods, name, something, template) =>
83+
val q = template match {
84+
case Template(superMaybe, emptyValDef, defs) =>
85+
Template(superMaybe, emptyValDef, defs ++ funcDef)
86+
case ex =>
87+
throw new IllegalArgumentException(s"Invalid template: $ex")
88+
}
89+
ClassDef(mods, name, something, q)
90+
case ModuleDef(mods, name, template) =>
91+
val q = template match {
92+
case Template(superMaybe, emptyValDef, defs) =>
93+
Template(superMaybe, emptyValDef, defs ++ funcDef)
94+
case ex =>
95+
throw new IllegalArgumentException(s"Invalid template: $ex")
96+
}
97+
ModuleDef(mods, name, q)
98+
case ex =>
99+
throw new IllegalArgumentException(s"Invalid macro input: $ex")
100+
}
101+
// wrap the result up in an Expr, and return it
102+
val result = c.Expr(Block(modDefs, Literal(Constant())))
103+
result
104+
}
105+
106+
protected def buildArgDefs(func: Func): List[String] = {
107+
func.listOfArgs.map(arg =>
108+
if (arg.isOptional)
109+
s"${arg.safeArgName} : Option[${arg.argType}] = None"
110+
else
111+
s"${arg.safeArgName} : ${arg.argType}"
112+
)
113+
}
114+
115+
116+
}

0 commit comments

Comments
 (0)