@@ -10,19 +10,19 @@ import com.microsoft.azure.synapse.ml.fabric.FabricClient
1010import com .microsoft .azure .synapse .ml .io .http ._
1111import com .microsoft .azure .synapse .ml .logging .SynapseMLLogging
1212import com .microsoft .azure .synapse .ml .logging .common .PlatformDetails
13- import com .microsoft .azure .synapse .ml .param .{ GlobalKey , GlobalParams , HasGlobalParams , ServiceParam }
14- import com .microsoft .azure .synapse .ml .stages .{ DropColumns , Lambda }
13+ import com .microsoft .azure .synapse .ml .param .{GlobalKey , GlobalParams , HasGlobalParams , ServiceParam , TypedArrayParam }
14+ import com .microsoft .azure .synapse .ml .stages .{DropColumns , Lambda }
1515import org .apache .commons .lang .StringUtils
1616import org .apache .http .NameValuePair
17- import org .apache .http .client .methods .{ HttpEntityEnclosingRequestBase , HttpPost , HttpRequestBase }
17+ import org .apache .http .client .methods .{HttpEntityEnclosingRequestBase , HttpPost , HttpRequestBase }
1818import org .apache .http .client .utils .URLEncodedUtils
1919import org .apache .http .entity .AbstractHttpEntity
2020import org .apache .http .impl .client .CloseableHttpClient
2121import org .apache .spark .ml .param ._
22- import org .apache .spark .ml .{ ComplexParamsWritable , NamespaceInjections , PipelineModel , Transformer }
23- import org .apache .spark .sql .functions .{ col , lit , struct }
22+ import org .apache .spark .ml .{ComplexParamsWritable , NamespaceInjections , PipelineModel , Transformer }
23+ import org .apache .spark .sql .functions .{col , lit , struct }
2424import org .apache .spark .sql .types ._
25- import org .apache .spark .sql .{ DataFrame , Dataset , Row }
25+ import org .apache .spark .sql .{DataFrame , Dataset , Row }
2626import spray .json .DefaultJsonProtocol ._
2727
2828import java .net .URI
@@ -206,13 +206,37 @@ trait HasCustomHeaders extends HasServiceParams {
206206
207207 // For Pyspark compatability accept Java HashMap as input to parameter
208208 // py4J only natively supports conversions from Python Dict to Java HashMap
209- def setCustomHeaders (v : java.util.HashMap [String ,String ]): this .type = {
209+ def setCustomHeaders (v : java.util.HashMap [String , String ]): this .type = {
210210 setCustomHeaders(v.asScala.toMap)
211211 }
212+ }
213+
214+ trait HasTelemHeaders extends HasServiceParams {
215+
216+ private [ml] val telemHeaders = new ServiceParam [Map [String , String ]](
217+ this , " telemHeaders" , " Map of Custom Header Key-Value Tuples."
218+ )
219+
220+ private [ml] def setTelemHeaders (v : Map [String , String ]): this .type = {
221+ setScalarParam(telemHeaders, v)
222+ }
223+
224+ // For Pyspark compatability accept Java HashMap as input to parameter
225+ // py4J only natively supports conversions from Python Dict to Java HashMap
226+ private [ml] def setTelemHeaders (v : java.util.HashMap [String , String ]): this .type = {
227+ setTelemHeaders(v.asScala.toMap)
228+ }
229+
230+ setDefault(telemHeaders -> Left (Map (" x-ai-telemetry-properties" ->
231+ s """ {
232+ |"OriginatingService": "SynapseML",
233+ |"ClientArtifactType": "Spark",
234+ |"OperationName": " ${this .getClass.getName}"
235+ |} """ .stripMargin.replaceAll(" \n " , " " ))))
212236
213- def getCustomHeaders : Map [String , String ] = getScalarParam(customHeaders)
214237}
215238
239+
216240trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
217241 def setCustomServiceName (v : String ): this .type = {
218242 setUrl(s " https:// $v.cognitiveservices.azure.com/ " + urlPath.stripPrefix(" /" ))
@@ -281,7 +305,7 @@ object URLEncodingUtils {
281305}
282306
283307trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
284- with HasCustomHeaders with SynapseMLLogging {
308+ with HasCustomHeaders with HasTelemHeaders with SynapseMLLogging {
285309
286310 val customUrlRoot : Param [String ] = new Param [String ](
287311 this , " customUrlRoot" , " The custom URL root for the service. " +
@@ -334,7 +358,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
334358
335359 protected def getCustomAuthHeader (row : Row ): Option [String ] = {
336360 val providedCustomAuthHeader = getValueOpt(row, CustomAuthHeader )
337- if (providedCustomAuthHeader .isEmpty && PlatformDetails .runningOnFabric()) {
361+ if (providedCustomAuthHeader.isEmpty && PlatformDetails .runningOnFabric()) {
338362 logInfo(" Using Default AAD Token On Fabric" )
339363 Option (FabricClient .getCognitiveMWCTokenAuthHeader)
340364 } else {
@@ -362,6 +386,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
362386 val contentTypeValue = contentType(row)
363387 val customAuthHeaderOpt = getCustomAuthHeader(row)
364388 val customHeadersOpt = getCustomHeaders(row)
389+ val telemHeadersOpt = getValueOpt(row, telemHeaders)
365390
366391 if (subscriptionKeyOpt.nonEmpty) {
367392 headers += (subscriptionKeyHeaderName -> getValue(row, subscriptionKey))
@@ -376,6 +401,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
376401 headers += (" x-ms-workload-resource-moniker" -> UUID .randomUUID().toString)
377402 }
378403 }
404+
379405 if (customHeadersOpt.nonEmpty) {
380406 customHeadersOpt.foreach { m =>
381407 m.foreach { case (headerName, headerValue) =>
@@ -384,10 +410,17 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
384410 }
385411 }
386412
413+ if (telemHeadersOpt.nonEmpty) {
414+ telemHeadersOpt.foreach { m =>
415+ m.foreach { case (headerName, headerValue) =>
416+ headers += (headerName -> headerValue)
417+ }
418+ }
419+ }
420+
387421 if (addContentType && ! StringUtils .isEmpty(contentTypeValue)) {
388422 headers += (" Content-Type" -> contentTypeValue)
389423 }
390-
391424 new scala.collection.immutable.TreeMap [String , String ]() ++ headers
392425 }
393426
@@ -514,7 +547,7 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform
514547 errorCol -> (this .uid + " _error" )
515548 )
516549
517- if (PlatformDetails .runningOnFabric()) {
550+ if (PlatformDetails .runningOnFabric()) {
518551 setDefaultInternalEndpoint(FabricClient .MLWorkloadEndpointML )
519552 }
520553
0 commit comments