Skip to content

Commit 26220da

Browse files
feat: add aifoundry to openai prompt (#2404)
* add aifoundry to openai prompt * modify OpenAIPrompt * style
1 parent aac2ed6 commit 26220da

File tree

4 files changed

+50
-13
lines changed

4 files changed

+50
-13
lines changed

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/aifoundry/AIFoundryChatCompletion.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ trait HasAIFoundryTextParamsExtended extends HasOpenAITextParamsExtended {
1717
val model = new ServiceParam[String](
1818
this, "model", "The name of the model", isRequired = true)
1919

20-
GlobalParams.registerParam(model, OpenAIDeploymentNameKey)
21-
2220
def getModel: String = getScalarParam(model)
2321

2422
def setModel(v: String): this.type = setScalarParam(model, v)

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, functions => F, ty
1717
import org.apache.spark.sql.catalyst.encoders.RowEncoder
1818
import org.apache.spark.sql.functions.{col, udf}
1919
import org.apache.spark.sql.types.{DataType, StructField, StructType}
20+
import com.microsoft.azure.synapse.ml.services.aifoundry.{AIFoundryChatCompletion, HasAIFoundryTextParamsExtended}
2021

2122
import scala.collection.JavaConverters._
2223

2324
object OpenAIPrompt extends ComplexParamsReadable[OpenAIPrompt]
2425

2526
class OpenAIPrompt(override val uid: String) extends Transformer
27+
with HasAIFoundryTextParamsExtended
2628
with HasOpenAITextParamsExtended with HasMessagesInput
2729
with HasErrorCol with HasOutputCol
2830
with HasURL with HasCustomCogServiceDomain with ConcurrencyParams
@@ -122,6 +124,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer
122124
setUrl(s"https://$v.openai.azure.com/" + urlPath.stripPrefix("/"))
123125
}
124126

127+
def setAIFoundryCustomServiceName(v: String): this.type = {
128+
setUrl(s"https://$v.services.ai.azure.com/" + urlPath.stripPrefix("/"))
129+
}
130+
125131
private val localParamNames = Seq(
126132
"promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages",
127133
"systemPrompt")
@@ -221,19 +227,26 @@ class OpenAIPrompt(override val uid: String) extends Transformer
221227
}
222228
}
223229

230+
private[openai] def hasAIFoundryModel: Boolean = this.isDefined(model)
231+
232+
// The deployment name can be set by the user; it does not have to match the model name.
224233
private val legacyModels = Set("ada", "babbage", "curie", "davinci",
225234
"text-ada-001", "text-babbage-001", "text-curie-001", "text-davinci-002", "text-davinci-003",
226235
"code-cushman-001", "code-davinci-002")
227236

228237
private def openAICompletion: OpenAIServicesBase = {
229238

230239
val completion: OpenAIServicesBase =
231-
if (legacyModels.contains(getDeploymentName)) {
240+
if (hasAIFoundryModel) {
241+
new AIFoundryChatCompletion()
242+
}
243+
else if (legacyModels.contains(getDeploymentName)) {
232244
new OpenAICompletion()
233245
}
234246
else {
235247
new OpenAIChatCompletion()
236248
}
249+
237250
// apply all parameters
238251
extractParamMap().toSeq
239252
.filter(p => !localParamNames.contains(p.param.name) && completion.hasParam(p.param.name))
@@ -247,6 +260,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer
247260
openAICompletion match {
248261
case chatCompletion: OpenAIChatCompletion =>
249262
chatCompletion.prepareEntity(r)
263+
case chatCompletion: AIFoundryChatCompletion =>
264+
chatCompletion.prepareEntity(r)
250265
case completion: OpenAICompletion =>
251266
completion.prepareEntity(r)
252267
}
@@ -270,6 +285,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer
270285
chatCompletion
271286
.transformSchema(schema.add(getMessagesCol, StructType(Seq())))
272287
.add(getPostProcessing, getParser.outputSchema)
288+
case chatCompletion: AIFoundryChatCompletion =>
289+
chatCompletion
290+
.transformSchema(schema.add(getMessagesCol, StructType(Seq())))
291+
.add(getPostProcessing, getParser.outputSchema)
273292
case completion: OpenAICompletion =>
274293
completion
275294
.transformSchema(schema)

cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/aifoundry/AIFoundryChatCompletionSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import org.scalactic.Equality
1414
trait AIFoundryAPIKey {
1515
lazy val aiFoundryAPIKey: String = sys.env.getOrElse("AI_FOUNDRY_API_KEY", Secrets.AIFoundryApiKey)
1616
lazy val aiFoundryServiceName: String = sys.env.getOrElse("AI_FOUNDRY_SERVICE_NAME", "synapseml-ai-foundry-resource")
17-
lazy val modelName: String = "Phi-4-mini-instruct"
17+
lazy val aiFoundryModelName: String = "Llama-3.3-70B-Instruct" //"Phi-4-mini-instruct"
1818
}
1919

2020
class AIFoundryChatCompletionSuite extends TransformerFuzzing[AIFoundryChatCompletion] with AIFoundryAPIKey with Flaky {
@@ -31,7 +31,7 @@ class AIFoundryChatCompletionSuite extends TransformerFuzzing[AIFoundryChatCompl
3131
.setTopP(0.1)
3232
.setPresencePenalty(0.0)
3333
.setFrequencyPenalty(0.0)
34-
.setModel(modelName)
34+
.setModel(aiFoundryModelName)
3535
.setSubscriptionKey(aiFoundryAPIKey)
3636

3737
lazy val goodDf: DataFrame = Seq(
@@ -153,7 +153,7 @@ class AIFoundryChatCompletionSuite extends TransformerFuzzing[AIFoundryChatCompl
153153

154154
val completion = new AIFoundryChatCompletion()
155155
.setCustomServiceName(aiFoundryServiceName)
156-
.setModel(modelName)
156+
.setModel(aiFoundryModelName)
157157
.setApiVersion("2024-05-01-preview")
158158
.setMaxTokens(500)
159159
.setOutputCol("out")

cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,18 @@
33

44
package com.microsoft.azure.synapse.ml.services.openai
55

6-
import com.microsoft.azure.synapse.ml.Secrets.getAccessToken
6+
import com.microsoft.azure.synapse.ml.Secrets.{AIFoundryApiKey, getAccessToken}
77
import com.microsoft.azure.synapse.ml.core.test.base.Flaky
88
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
99
import org.apache.spark.ml.util.MLReadable
1010
import org.apache.spark.sql.{DataFrame, Row}
1111
import org.apache.spark.sql.functions.col
1212
import org.scalactic.Equality
13+
import com.microsoft.azure.synapse.ml.services.aifoundry.AIFoundryAPIKey
1314

14-
class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIKey with Flaky {
15+
class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIKey
16+
with AIFoundryAPIKey
17+
with Flaky {
1518

1619
import spark.implicits._
1720

@@ -28,11 +31,18 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
2831
.setOutputCol("outParsed")
2932
.setTemperature(0)
3033

34+
lazy val aiFoundryPrompt: OpenAIPrompt = new OpenAIPrompt()
35+
.setSubscriptionKey(aiFoundryAPIKey)
36+
.setApiVersion("2024-05-01-preview")
37+
.setModel(aiFoundryModelName)
38+
.setAIFoundryCustomServiceName(aiFoundryServiceName)
39+
.setOutputCol("outParsed")
40+
.setTemperature(0)
41+
3142
lazy val df: DataFrame = Seq(
3243
("apple", "fruits"),
3344
("mercedes", "cars"),
34-
("cake", "dishes"),
35-
(null, "none") //scalastyle:ignore null
45+
("cake", "dishes")
3646
).toDF("text", "category")
3747

3848
test("RAI Usage") {
@@ -50,13 +60,23 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
5060

5161
test("Basic Usage") {
5262
val nonNullCount = prompt
53-
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
63+
.setPromptTemplate("give me a comma separated list of 5 {category}, starting with {text} ")
5464
.setPostProcessing("csv")
5565
.transform(df)
5666
.select("outParsed")
5767
.collect()
5868
.count(r => Option(r.getSeq[String](0)).isDefined)
69+
assert(nonNullCount == 3)
70+
}
5971

72+
test("Basic Usage AI Foundry") {
73+
val nonNullCount = aiFoundryPrompt
74+
.setPromptTemplate("give me a comma separated list of 5 {category}, starting with {text} ")
75+
.setPostProcessing("csv")
76+
.transform(df)
77+
.select("outParsed")
78+
.collect()
79+
.count(r => Option(r.getSeq[String](0)).isDefined)
6080
assert(nonNullCount == 3)
6181
}
6282

@@ -84,7 +104,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
84104

85105
test("Basic Usage - Gpt 4") {
86106
val nonNullCount = promptGpt4
87-
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
107+
.setPromptTemplate("give me a comma separated list of 5 {category}, starting with {text} ")
88108
.setPostProcessing("csv")
89109
.transform(df)
90110
.select("outParsed")
@@ -189,7 +209,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
189209
.setCustomHeaders(customHeadersValues)
190210
}
191211

192-
customPromptGpt4.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
212+
customPromptGpt4.setPromptTemplate("give me a comma separated list of 5 {category}, starting with {text} ")
193213
.setPostProcessing("csv")
194214
.transform(df)
195215
.select("outParsed")

0 commit comments

Comments
 (0)