@@ -66,7 +66,7 @@ test_that("spark.gbt", {
   # label must be binary - GBTClassifier currently only supports binary classification.
   iris2 <- iris[iris$Species != "virginica", ]
   data <- suppressWarnings(createDataFrame(iris2))
-  model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification")
+  model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification", seed = 12)
   stats <- summary(model)
   expect_equal(stats$numFeatures, 2)
   expect_equal(stats$numTrees, 20)
@@ -94,7 +94,7 @@ test_that("spark.gbt", {
 
   iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1)
   df <- suppressWarnings(createDataFrame(iris2))
-  m <- spark.gbt(df, NumericSpecies ~ ., type = "classification")
+  m <- spark.gbt(df, NumericSpecies ~ ., type = "classification", seed = 12)
   s <- summary(m)
   # test numeric prediction values
   expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
@@ -106,7 +106,7 @@ test_that("spark.gbt", {
   if (windows_with_hadoop()) {
     data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"),
                     source = "libsvm")
-    model <- spark.gbt(data, label ~ features, "classification")
+    model <- spark.gbt(data, label ~ features, "classification", seed = 12)
     expect_equal(summary(model)$numFeatures, 692)
   }
 
@@ -117,10 +117,11 @@ test_that("spark.gbt", {
   trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
   traindf <- as.DataFrame(data[trainidxs, ])
   testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
-  model <- spark.gbt(traindf, clicked ~ ., type = "classification")
+  model <- spark.gbt(traindf, clicked ~ ., type = "classification", seed = 23)
   predictions <- predict(model, testdf)
   expect_error(collect(predictions))
-  model <- spark.gbt(traindf, clicked ~ ., type = "classification", handleInvalid = "keep")
+  model <- spark.gbt(traindf, clicked ~ ., type = "classification", handleInvalid = "keep",
+                     seed = 23)
   predictions <- predict(model, testdf)
   expect_equal(class(collect(predictions)$clicked[1]), "character")
 })
@@ -129,7 +130,7 @@ test_that("spark.randomForest", {
   # regression
   data <- suppressWarnings(createDataFrame(longley))
   model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
-                              numTrees = 1)
+                              numTrees = 1, seed = 1)
 
   predictions <- collect(predict(model, data))
   expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
@@ -177,7 +178,7 @@ test_that("spark.randomForest", {
   # classification
   data <- suppressWarnings(createDataFrame(iris))
   model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification",
-                              maxDepth = 5, maxBins = 16)
+                              maxDepth = 5, maxBins = 16, seed = 123)
 
   stats <- summary(model)
   expect_equal(stats$numFeatures, 2)
@@ -215,7 +216,7 @@ test_that("spark.randomForest", {
   iris$NumericSpecies <- lapply(iris$Species, labelToIndex)
   data <- suppressWarnings(createDataFrame(iris[-5]))
   model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification",
-                              maxDepth = 5, maxBins = 16)
+                              maxDepth = 5, maxBins = 16, seed = 123)
   stats <- summary(model)
   expect_equal(stats$numFeatures, 2)
   expect_equal(stats$numTrees, 20)
@@ -234,28 +235,29 @@ test_that("spark.randomForest", {
   traindf <- as.DataFrame(data[trainidxs, ])
   testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
   model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
-                              maxDepth = 10, maxBins = 10, numTrees = 10)
+                              maxDepth = 10, maxBins = 10, numTrees = 10, seed = 123)
   predictions <- predict(model, testdf)
   expect_error(collect(predictions))
   model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
                               maxDepth = 10, maxBins = 10, numTrees = 10,
-                              handleInvalid = "keep")
+                              handleInvalid = "keep", seed = 123)
   predictions <- predict(model, testdf)
   expect_equal(class(collect(predictions)$clicked[1]), "character")
 
   # spark.randomForest classification can work on libsvm data
   if (windows_with_hadoop()) {
     data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
                     source = "libsvm")
-    model <- spark.randomForest(data, label ~ features, "classification")
+    model <- spark.randomForest(data, label ~ features, "classification", seed = 123)
     expect_equal(summary(model)$numFeatures, 4)
   }
 })
 
 test_that("spark.decisionTree", {
   # regression
   data <- suppressWarnings(createDataFrame(longley))
-  model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16)
+  model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
+                              seed = 42)
 
   predictions <- collect(predict(model, data))
   expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
@@ -288,7 +290,7 @@ test_that("spark.decisionTree", {
   # classification
   data <- suppressWarnings(createDataFrame(iris))
   model <- spark.decisionTree(data, Species ~ Petal_Length + Petal_Width, "classification",
-                              maxDepth = 5, maxBins = 16)
+                              maxDepth = 5, maxBins = 16, seed = 43)
 
   stats <- summary(model)
   expect_equal(stats$numFeatures, 2)
@@ -325,7 +327,7 @@ test_that("spark.decisionTree", {
   iris$NumericSpecies <- lapply(iris$Species, labelToIndex)
   data <- suppressWarnings(createDataFrame(iris[-5]))
   model <- spark.decisionTree(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification",
-                              maxDepth = 5, maxBins = 16)
+                              maxDepth = 5, maxBins = 16, seed = 44)
   stats <- summary(model)
   expect_equal(stats$numFeatures, 2)
   expect_equal(stats$maxDepth, 5)
@@ -339,7 +341,7 @@ test_that("spark.decisionTree", {
   if (windows_with_hadoop()) {
     data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
                     source = "libsvm")
-    model <- spark.decisionTree(data, label ~ features, "classification")
+    model <- spark.decisionTree(data, label ~ features, "classification", seed = 45)
     expect_equal(summary(model)$numFeatures, 4)
   }
 
@@ -351,11 +353,11 @@ test_that("spark.decisionTree", {
   traindf <- as.DataFrame(data[trainidxs, ])
   testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
   model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
-                              maxDepth = 5, maxBins = 16)
+                              maxDepth = 5, maxBins = 16, seed = 46)
   predictions <- predict(model, testdf)
   expect_error(collect(predictions))
   model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
-                              maxDepth = 5, maxBins = 16, handleInvalid = "keep")
+                              maxDepth = 5, maxBins = 16, handleInvalid = "keep", seed = 46)
   predictions <- predict(model, testdf)
   expect_equal(class(collect(predictions)$clicked[1]), "character")
 })