1+ import os
12import unittest
23from array import array
34
@@ -38,10 +39,9 @@ def test_inmemory_tree(self):
3839 x [0 ] = i
3940 tree .Fill ()
4041
41- headnode = create_dummy_headnode (tree )
4242 with self .assertRaises (ROOT .std .runtime_error ):
4343 # Trees with no associated files are not supported
44- headnode . get_inputfiles ( )
44+ create_dummy_headnode ( tree )
4545
4646 def assertArgs (self , args_list1 , args_list2 ):
4747 """
@@ -85,9 +85,14 @@ def test_two_args(self):
8585 # RDataFrame constructor with 2nd argument as ROOT CPP Vector
8686 hn_3 = create_dummy_headnode ("treename" , reqd_vec )
8787
88- self .assertArgs (hn_1 .args , ["treename" , "file.root" ])
89- self .assertArgs (hn_2 .args , ["treename" , rdf_2_files ])
90- self .assertArgs (hn_3 .args , ["treename" , reqd_vec ])
88+ for hn in (hn_1 , hn_2 , hn_3 ):
89+ self .assertEqual (hn .treename , "treename" )
90+
91+ self .assertListEqual (hn_1 .inputfiles , ["file.root" ])
92+ self .assertListEqual (hn_2 .inputfiles , rdf_2_files )
93+ # hn_3 got file names as std::vector<std::string> but the TreeHeadNode
94+ # instance stores it as list[str]
95+ self .assertListEqual (hn_3 .inputfiles , rdf_2_files )
9196
9297 def test_three_args_with_single_file (self ):
9398 """Constructor with TTree, one input file and selected branches"""
@@ -104,8 +109,13 @@ def test_three_args_with_single_file(self):
104109 # RDataFrame constructor with 3rd argument as ROOT CPP Vector
105110 hn_2 = create_dummy_headnode ("treename" , "file.root" , reqd_vec )
106111
107- self .assertArgs (hn_1 .args , ["treename" , "file.root" , rdf_branches ])
108- self .assertArgs (hn_2 .args , ["treename" , "file.root" , reqd_vec ])
112+ for hn in (hn_1 , hn_2 ):
113+ self .assertEqual (hn .treename , "treename" )
114+ self .assertListEqual (hn .inputfiles , ["file.root" ])
115+
116+ self .assertListEqual (hn_1 .defaultbranches , rdf_branches )
117+ self .assertIsInstance (hn_2 .defaultbranches , type (reqd_vec ))
118+ self .assertListEqual (list (hn_2 .defaultbranches ), list (reqd_vec ))
109119
110120 def test_three_args_with_multiple_files (self ):
111121 """Constructor with TTree, list of input files and selected branches"""
@@ -138,12 +148,15 @@ def test_three_args_with_multiple_files(self):
138148 # CPP Vectors
139149 hn_4 = create_dummy_headnode ("treename" , reqd_files_vec , reqd_branches_vec )
140150
141- self .assertArgs (hn_1 .args , ["treename" , rdf_files , rdf_branches ])
142- self .assertArgs (hn_2 .args , ["treename" , rdf_files , reqd_branches_vec ])
143- self .assertArgs (hn_3 .args , ["treename" , reqd_files_vec , rdf_branches ])
144- self .assertArgs (
145- hn_4 .args , ["treename" , reqd_files_vec , reqd_branches_vec ])
151+ for hn in (hn_1 , hn_2 , hn_3 , hn_4 ):
152+ self .assertEqual (hn .treename , "treename" )
153+ self .assertListEqual (hn .inputfiles , rdf_files )
154+ self .assertListEqual (list (hn .defaultbranches ), rdf_branches )
146155
156+ self .assertIsInstance (hn_1 .defaultbranches , type (rdf_branches ))
157+ self .assertIsInstance (hn_3 .defaultbranches , type (rdf_branches ))
158+ self .assertIsInstance (hn_2 .defaultbranches , type (reqd_branches_vec ))
159+ self .assertIsInstance (hn_4 .defaultbranches , type (reqd_branches_vec ))
147160
148161class NumEntriesTest (unittest .TestCase ):
149162 """'get_num_entries' returns the number of entries in the input dataset"""
@@ -168,9 +181,9 @@ def test_num_entries_two_args_case(self):
168181 hn_1 = create_dummy_headnode ("tree" , ["data.root" ])
169182 hn_2 = create_dummy_headnode ("tree" , files_vec )
170183
171- self .assertEqual (hn .get_num_entries (), 1111 )
172- self .assertEqual (hn_1 .get_num_entries (), 1111 )
173- self .assertEqual (hn_2 .get_num_entries (), 1111 )
184+ self .assertEqual (hn .tree . GetEntries (), 1111 )
185+ self .assertEqual (hn_1 .tree . GetEntries (), 1111 )
186+ self .assertEqual (hn_2 .tree . GetEntries (), 1111 )
174187
175188 def test_num_entries_three_args_case (self ):
176189 """
@@ -190,17 +203,20 @@ def test_num_entries_three_args_case(self):
190203 hn_2 = create_dummy_headnode ("tree" , "data.root" , branches_vec_1 )
191204 hn_3 = create_dummy_headnode ("tree" , "data.root" , branches_vec_2 )
192205
193- self .assertEqual (hn .get_num_entries (), 1234 )
194- self .assertEqual (hn_1 .get_num_entries (), 1234 )
195- self .assertEqual (hn_2 .get_num_entries (), 1234 )
196- self .assertEqual (hn_3 .get_num_entries (), 1234 )
206+ self .assertEqual (hn .tree . GetEntries (), 1234 )
207+ self .assertEqual (hn_1 .tree . GetEntries (), 1234 )
208+ self .assertEqual (hn_2 .tree . GetEntries (), 1234 )
209+ self .assertEqual (hn_3 .tree . GetEntries (), 1234 )
197210
198211 def test_num_entries_with_ttree_arg (self ):
199212 """
200213 Ensure that the number of entries recorded are correct in the case
201214 of RDataFrame constructor with a TTree.
202215
203216 """
217+ filename = "test_num_entries_with_ttree_arg.root"
218+ f = ROOT .TFile (filename , "recreate" )
219+
204220 tree = ROOT .TTree ("tree" , "test" ) # Create tree
205221 v = ROOT .std .vector ("int" )(4 ) # Create a vector of 0s of size 4
206222 tree .Branch ("vectorb" , v ) # Create branch to hold the vector
@@ -209,6 +225,11 @@ def test_num_entries_with_ttree_arg(self):
209225 v [i ] = 1 # Change the vector element to 1
210226 tree .Fill () # Fill the tree with that element
211227
228+ f .Write ()
229+
212230 hn = create_dummy_headnode (tree )
213231
214- self .assertEqual (hn .get_num_entries (), 4 )
232+ self .assertEqual (hn .tree .GetEntries (), 4 )
233+
234+ f .Close ()
235+ os .remove (filename )
0 commit comments