Skip to content

Commit 275cd68

Browse files
committed
Add extra test for friend trees
1 parent e50caed commit 275cd68

File tree

1 file changed

+71
-22
lines changed

1 file changed

+71
-22
lines changed

bindings/experimental/distrdf/test/test_friendinfo.py

Lines changed: 71 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,37 +9,33 @@
99
class FriendInfoTest(unittest.TestCase):
1010
"""Unit test for the FriendInfo class"""
1111

12-
def create_parent_tree(self):
12+
def create_parent_tree(self, treename, filename):
1313
"""Creates a .root file with the parent TTree"""
14-
f = ROOT.TFile("treeparent.root", "recreate")
15-
t = ROOT.TTree("T", "test friend trees")
14+
f = ROOT.TFile(filename, "recreate")
15+
t = ROOT.TTree(treename, "parent tree")
1616

17-
x = array("f", [0])
18-
t.Branch("x", x, "x/F")
17+
x = array("i", [0])
18+
t.Branch("x", x, "x/I")
1919

20-
r = ROOT.TRandom()
21-
# The parent will have a gaussian distribution with mean 10 and
22-
# standard deviation 1
23-
for _ in range(10000):
24-
x[0] = r.Gaus(10, 1)
20+
for i in range(9):
21+
x[0] = i
2522
t.Fill()
2623

2724
f.Write()
2825
f.Close()
2926

30-
def create_friend_tree(self):
27+
def create_friend_tree(self, treename, filename):
3128
"""Creates a .root file with the friend TTree"""
32-
ff = ROOT.TFile("treefriend.root", "recreate")
33-
tf = ROOT.TTree("TF", "tree friend")
29+
ff = ROOT.TFile(filename, "recreate")
30+
tf = ROOT.TTree(treename, "friend tree")
3431

35-
x = array("f", [0])
36-
tf.Branch("x", x, "x/F")
32+
y = array("i", [0])
33+
tf.Branch("y", y, "y/I")
3734

38-
r = ROOT.TRandom()
3935
# The friend will have a gaussian distribution with mean 20 and
4036
# standard deviation 1
41-
for _ in range(10000):
42-
x[0] = r.Gaus(20, 1)
37+
for i in range(3):
38+
y[0] = i
4339
tf.Fill()
4440

4541
ff.Write()
@@ -50,24 +46,23 @@ def test_friend_info_with_ttree(self):
5046
Check that FriendInfo correctly stores information about the friend
5147
trees
5248
"""
53-
self.create_parent_tree()
54-
self.create_friend_tree()
55-
5649
# Parent Tree
5750
base_tree_name = "T"
5851
base_tree_filename = "treeparent.root"
52+
self.create_parent_tree(base_tree_name, base_tree_filename)
5953
basetree = ROOT.TChain(base_tree_name)
6054
basetree.Add(base_tree_filename)
6155

6256
# Friend Tree
6357
friend_tree_name = "TF"
6458
friend_tree_alias = "TF"
6559
friend_tree_filename = "treefriend.root"
60+
self.create_friend_tree(friend_tree_name, friend_tree_filename)
6661
friendtree = ROOT.TChain(friend_tree_name)
6762
friendtree.Add(friend_tree_filename)
6863

6964
# Add friendTree to the parent
70-
basetree.AddFriend(friendtree)
65+
basetree.AddFriend(friendtree, friend_tree_alias)
7166

7267
# Instantiate head node of the graph with the base TTree
7368
headnode = Factory.get_headnode(basetree)
@@ -89,3 +84,57 @@ def test_friend_info_with_ttree(self):
8984
# Remove unnecessary .root files
9085
os.remove(base_tree_filename)
9186
os.remove(friend_tree_filename)
87+
88+
89+
def test_friend_info_chain_with_subnames(self):
90+
"""
91+
Check that FriendInfo correctly stores information about the friend
92+
trees
93+
"""
94+
# Parent Tree
95+
parent_name = "treeparent"
96+
parent_filename = "treeparent.root"
97+
self.create_parent_tree(parent_name, parent_filename)
98+
parenttree = ROOT.TChain(parent_name)
99+
parenttree.Add(parent_filename)
100+
101+
# Friend chain
102+
friendchainname = "treefriend"
103+
friendchain = ROOT.TChain(friendchainname)
104+
actualfriendchainfilenames = []
105+
actualfriendchainsubnames = []
106+
for i in range(1,4):
107+
friend_name = "treefriend" + str(i)
108+
friend_filename = friend_name + ".root"
109+
actualfriendchainfilenames.append(friend_filename)
110+
actualfriendchainsubnames.append(friend_name)
111+
self.create_friend_tree(friend_name, friend_filename)
112+
friendchain.Add(friend_filename + "/" + friend_name)
113+
114+
# Add friendTree to the parent
115+
parenttree.AddFriend(friendchain)
116+
117+
# Instantiate head node of the graph with the base TTree
118+
headnode = Factory.get_headnode(parenttree)
119+
120+
# Retrieve information about friends
121+
friendnamesalias, friendfilenames, friendchainsubnames = headnode.get_friendinfo()
122+
123+
print(friendnamesalias)
124+
print(friendfilenames)
125+
print(friendchainsubnames)
126+
# Check that FriendInfo has non-empty lists
127+
self.assertIsNotNone(friendnamesalias)
128+
self.assertIsNotNone(friendfilenames)
129+
self.assertIsNotNone(friendchainsubnames)
130+
131+
# Check that the three lists with treenames, filenames and subnames are populated
132+
# as expected.
133+
self.assertTupleEqual(friendnamesalias, ((friendchainname, friendchainname),))
134+
self.assertTupleEqual(friendfilenames, (tuple(actualfriendchainfilenames),) )
135+
self.assertTupleEqual(friendchainsubnames, (tuple(actualfriendchainsubnames),) )
136+
137+
# Remove unnecessary .root files
138+
os.remove(parent_filename)
139+
for filename in actualfriendchainfilenames:
140+
os.remove(filename)

0 commit comments

Comments
 (0)