1717
1818package org .apache .mxnet
1919
20- import org .apache .mxnet .init .Base ._
21- import org .apache .mxnet .utils .CToScalaUtils
2220import java .io ._
2321import java .security .MessageDigest
2422
@@ -29,13 +27,11 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
2927 * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
3028 * The code will be executed during Macros stage and file live in Core stage
3129 */
32- private [mxnet] object APIDocGenerator {
33- case class absClassArg (argName : String , argType : String , argDesc : String , isOptional : Boolean )
34- case class absClassFunction (name : String , desc : String ,
35- listOfArgs : List [absClassArg], returnType : String )
30+ private [mxnet] object APIDocGenerator extends GeneratorBase {
31+ type absClassArg = Arg
32+ type absClassFunction = Func
3633
37-
38- def main (args : Array [String ]) : Unit = {
34+ def main (args : Array [String ]): Unit = {
3935 val FILE_PATH = args(0 )
4036 val hashCollector = ListBuffer [String ]()
4137 hashCollector += absClassGen(FILE_PATH , true )
@@ -47,68 +43,70 @@ private[mxnet] object APIDocGenerator{
4743 val finalHash = hashCollector.mkString(" \n " )
4844 }
4945
50- def MD5Generator (input : String ) : String = {
46+ def MD5Generator (input : String ): String = {
5147 val md = MessageDigest .getInstance(" MD5" )
5248 md.update(input.getBytes(" UTF-8" ))
5349 val digest = md.digest()
5450 org.apache.commons.codec.binary.Base64 .encodeBase64URLSafeString(digest)
5551 }
5652
57- def absRndClassGen (FILE_PATH : String , isSymbol : Boolean ) : String = {
58- typeSafeClassGen(
59- getSymbolNDArrayMethods(isSymbol)
60- .filter(f => f.name.startsWith(" _random" ) || f.name.startsWith(" _sample" ))
61- .map(f => f.copy(name = f.name.stripPrefix(" _" ))),
53+ def absRndClassGen (FILE_PATH : String , isSymbol : Boolean ): String = {
54+ val funcs = getSymbolNDArrayMethods(isSymbol)
55+ .filter(f => f.name.startsWith(" _sample_" ) || f.name.startsWith(" _random_" ))
56+ .map(f => f.copy(name = f.name.stripPrefix(" _" )))
57+ val body = funcs.map(func => {
58+ val scalaDoc = generateAPIDocFromBackend(func)
59+ val decl = generateRandomAPISignature(func, isSymbol)
60+ s " $scalaDoc\n $decl"
61+ })
62+ writeFile(
6263 FILE_PATH ,
6364 if (isSymbol) " SymbolRandomAPIBase" else " NDArrayRandomAPIBase" ,
64- isSymbol
65- )
65+ body)
6666 }
6767
68- def absClassGen (FILE_PATH : String , isSymbol : Boolean ) : String = {
68+ def absClassGen (FILE_PATH : String , isSymbol : Boolean ): String = {
6969 val notGenerated = Set (" Custom" )
70- typeSafeClassGen(
71- getSymbolNDArrayMethods(isSymbol)
72- .filterNot(_.name.startsWith(" _" ))
73- .filterNot(ele => notGenerated.contains(ele.name)),
70+ val funcs = getSymbolNDArrayMethods(isSymbol)
71+ .filterNot(_.name.startsWith(" _" ))
72+ .filterNot(ele => notGenerated.contains(ele.name))
73+ val body = funcs.map(func => {
74+ val scalaDoc = generateAPIDocFromBackend(func)
75+ val decl = generateAPISignature(func, isSymbol)
76+ s " $scalaDoc\n $decl"
77+ })
78+ writeFile(
7479 FILE_PATH ,
7580 if (isSymbol) " SymbolAPIBase" else " NDArrayAPIBase" ,
76- isSymbol
77- )
81+ body)
7882 }
7983
80- def typeSafeClassGen (absClassFunctions : Seq [absClassFunction], FILE_PATH : String ,
81- packageName : String , isSymbol : Boolean ): String = {
82- val absFuncs = absClassFunctions
83- .map(absClassFunction => {
84- val scalaDoc = generateAPIDocFromBackend(absClassFunction)
85- val defBody = generateAPISignature(absClassFunction, isSymbol)
86- s " $scalaDoc\n $defBody"
87- })
88- writeFile(FILE_PATH , packageName, absFuncs)
89- }
90-
91- def nonTypeSafeClassGen (FILE_PATH : String , isSymbol : Boolean ) : String = {
92- // scalastyle:off
84+ def nonTypeSafeClassGen (FILE_PATH : String , isSymbol : Boolean ): String = {
9385 val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
9486 val absFuncs = absClassFunctions
9587 .filterNot(_.name.startsWith(" _" ))
9688 .map(absClassFunction => {
97- val scalaDoc = generateAPIDocFromBackend(absClassFunction, false )
98- if (isSymbol) {
99- val defBody = s " def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol "
100- s " $scalaDoc\n $defBody"
101- } else {
102- val defBodyWithKwargs = s " def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn "
103- val defBody = s " def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn "
104- s " $scalaDoc\n $defBodyWithKwargs\n $scalaDoc\n $defBody"
105- }
106- })
89+ val scalaDoc = generateAPIDocFromBackend(absClassFunction, false )
90+ if (isSymbol) {
91+ val defBody =
92+ s " def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null) " +
93+ s " (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): " +
94+ s " org.apache.mxnet.Symbol "
95+ s " $scalaDoc\n $defBody"
96+ } else {
97+ val defBodyWithKwargs = s " def ${absClassFunction.name}(kwargs: Map[String, Any] = null) " +
98+ s " (args: Any*): " +
99+ s " org.apache.mxnet.NDArrayFuncReturn "
100+ val defBody = s " def ${absClassFunction.name}(args: Any*): " +
101+ s " org.apache.mxnet.NDArrayFuncReturn "
102+ s " $scalaDoc\n $defBodyWithKwargs\n $scalaDoc\n $defBody"
103+ }
104+ })
107105 val packageName = if (isSymbol) " SymbolBase" else " NDArrayBase"
108106 writeFile(FILE_PATH , packageName, absFuncs)
109107 }
110108
111- def writeFile (FILE_PATH : String , packageName : String , absFuncs : Seq [String ]): String = {
109+ def writeFile (FILE_PATH : String , packageName : String , body : Seq [String ]): String = {
112110 val apacheLicence =
113111 """ /*
114112 |* Licensed to the Apache Software Foundation (ASF) under one or more
@@ -137,7 +135,7 @@ private[mxnet] object APIDocGenerator{
137135 | $packageDef
138136 | $imports
139137 | $absClassDef {
140- | ${absFuncs .mkString(" \n " )}
138+ | ${body .mkString(" \n " )}
141139 |} """ .stripMargin
142140 val pw = new PrintWriter (new File (FILE_PATH + s " $packageName.scala " ))
143141 pw.write(finalStr)
@@ -146,20 +144,15 @@ private[mxnet] object APIDocGenerator{
146144 }
147145
148146 // Generate ScalaDoc type
149- def generateAPIDocFromBackend (func : absClassFunction, withParam : Boolean = true ) : String = {
147+ def generateAPIDocFromBackend (func : absClassFunction, withParam : Boolean = true ): String = {
150148 val desc = ArrayBuffer [String ]()
151149 desc += " * <pre>"
152- func.desc.split(" \n " ).foreach({ currStr =>
150+ func.desc.split(" \n " ).foreach({ currStr =>
153151 desc += s " * $currStr"
154152 })
155153 desc += " * </pre>"
156154 val params = func.listOfArgs.map({ absClassArg =>
157- val currArgName = absClassArg.argName match {
158- case " var" => " vari"
159- case " type" => " typeOf"
160- case _ => absClassArg.argName
161- }
162- s " * @param $currArgName\t\t ${absClassArg.argDesc}"
155+ s " * @param ${absClassArg.safeArgName}\t\t ${absClassArg.argDesc}"
163156 })
164157 val returnType = s " * @return ${func.returnType}"
165158 if (withParam) {
@@ -169,64 +162,31 @@ private[mxnet] object APIDocGenerator{
169162 }
170163 }
171164
172- def generateAPISignature (func : absClassFunction, isSymbol : Boolean ) : String = {
173- var argDef = ListBuffer [String ]()
174- func.listOfArgs.foreach(absClassArg => {
175- val currArgName = absClassArg.argName match {
176- case " var" => " vari"
177- case " type" => " typeOf"
178- case _ => absClassArg.argName
179- }
180- if (absClassArg.isOptional) {
181- argDef += s " $currArgName : Option[ ${absClassArg.argType}] = None "
182- }
183- else {
184- argDef += s " $currArgName : ${absClassArg.argType}"
185- }
186- })
187- var returnType = func.returnType
165+ def generateRandomAPISignature (func : absClassFunction, isSymbol : Boolean ): String = {
166+ generateAPISignature(func, isSymbol)
167+ }
168+
169+ def generateAPISignature (func : absClassFunction, isSymbol : Boolean ): String = {
170+ val argDef = ListBuffer [String ]()
171+
172+ argDef ++= buildArgDefs(func)
173+
188174 if (isSymbol) {
189175 argDef += " name : String = null"
190176 argDef += " attr : Map[String, String] = null"
191177 } else {
192178 argDef += " out : Option[NDArray] = None"
193- returnType = " org.apache.mxnet.NDArrayFuncReturn"
194179 }
180+
181+ val returnType = func.returnType
182+
195183 val experimentalTag = " @Experimental"
196184 s " $experimentalTag\n def ${func.name} ( ${argDef.mkString(" , " )}) : $returnType"
197185 }
198186
199187 // List and add all the atomic symbol functions to current module.
200- private def getSymbolNDArrayMethods (isSymbol : Boolean ): List [absClassFunction] = {
201- val opNames = ListBuffer .empty[String ]
202- val returnType = if (isSymbol) " Symbol" else " NDArray"
203- _LIB.mxListAllOpNames(opNames)
204- // TODO: Add '_linalg_', '_sparse_', '_image_' support
205- // TODO: Add Filter to the same location in case of refactor
206- opNames.map(opName => {
207- val opHandle = new RefLong
208- _LIB.nnGetOpHandle(opName, opHandle)
209- makeAtomicSymbolFunction(opHandle.value, opName, " org.apache.mxnet." + returnType)
210- }).toList
188+ private def getSymbolNDArrayMethods (isSymbol : Boolean ): List [absClassFunction] = {
189+ buildFunctionList(isSymbol)
211190 }
212191
213- // Create an atomic symbol function by handle and function name.
214- private def makeAtomicSymbolFunction (handle : SymbolHandle , aliasName : String , returnType : String )
215- : absClassFunction = {
216- val name = new RefString
217- val desc = new RefString
218- val keyVarNumArgs = new RefString
219- val numArgs = new RefInt
220- val argNames = ListBuffer .empty[String ]
221- val argTypes = ListBuffer .empty[String ]
222- val argDescs = ListBuffer .empty[String ]
223-
224- _LIB.mxSymbolGetAtomicSymbolInfo(
225- handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
226- val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
227- val typeAndOption = CToScalaUtils .argumentCleaner(argName, argType, returnType)
228- absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
229- }
230- absClassFunction(aliasName, desc.value, argList.toList, returnType)
231- }
232192}
0 commit comments