@@ -558,7 +558,7 @@ def _scala_test_impl(ctx):
558558 return _scala_binary_common (ctx , cjars , rjars )
559559
560560def _gen_test_suite_flags_based_on_prefixes_and_suffixes (ctx , archive ):
561- return struct (suite_class = "io.bazel.rulesscala.test_discovery.DiscoveredTestSuite" ,
561+ return struct (testSuiteFlag = "-Dbazel.test_suite= io.bazel.rulesscala.test_discovery.DiscoveredTestSuite" ,
562562 archiveFlag = "-Dbazel.discover.classes.archive.file.path=%s" % archive .short_path ,
563563 prefixesFlag = "-Dbazel.discover.classes.prefixes=%s" % "," .join (ctx .attr .prefixes ),
564564 suffixesFlag = "-Dbazel.discover.classes.suffixes=%s" % "," .join (ctx .attr .suffixes ),
@@ -568,20 +568,19 @@ def _scala_junit_test_impl(ctx):
568568 if (not (ctx .attr .prefixes ) and not (ctx .attr .suffixes )):
569569 fail ("Setting at least one of the attributes ('prefixes','suffixes') is required" )
570570 jars = _collect_jars_from_common_ctx (ctx ,
571- extra_deps = [ctx .attr ._junit , ctx .attr ._hamcrest , ctx .attr ._suite ],
571+ extra_deps = [ctx .attr ._junit , ctx .attr ._hamcrest , ctx .attr ._suite , ctx . attr . _bazel_test_runner ],
572572 )
573573 (cjars , rjars ) = (jars .compiletime , jars .runtime )
574574
575575 rjars += [ctx .outputs .jar ]
576576
577577 test_suite = _gen_test_suite_flags_based_on_prefixes_and_suffixes (ctx , ctx .outputs .jar )
578- launcherJvmFlags = ["-ea" , test_suite .archiveFlag , test_suite .prefixesFlag , test_suite .suffixesFlag , test_suite .printFlag ]
578+ launcherJvmFlags = ["-ea" , test_suite .archiveFlag , test_suite .prefixesFlag , test_suite .suffixesFlag , test_suite .printFlag , test_suite . testSuiteFlag ]
579579 _write_launcher (
580580 ctx = ctx ,
581581 rjars = rjars ,
582- main_class = "org. junit.runner.JUnitCore " ,
582+ main_class = "com.google.testing. junit.runner.BazelTestRunner " ,
583583 jvm_flags = launcherJvmFlags + ctx .attr .jvm_flags ,
584- args = test_suite .suite_class ,
585584 )
586585
587586 return _scala_binary_common (ctx , cjars , rjars )
@@ -865,6 +864,7 @@ scala_junit_test = rule(
865864 "_junit" : attr .label (default = Label ("//external:io_bazel_rules_scala/dependency/junit/junit" )),
866865 "_hamcrest" : attr .label (default = Label ("//external:io_bazel_rules_scala/dependency/hamcrest/hamcrest_core" )),
867866 "_suite" : attr .label (default = Label ("//src/java/io/bazel/rulesscala/test_discovery:test_discovery" )),
867+ "_bazel_test_runner" : attr .label (default = Label ("@bazel_tools//tools/jdk:TestRunner_deploy.jar" ), allow_files = True ),
868868 },
869869 outputs = {
870870 "jar" : "%{name}.jar" ,
0 commit comments