Skip to content

Commit 6b06095

Browse files
authored
wrongly assumed add op had only 2 operands (#1618)
1 parent 0c10130 commit 6b06095

File tree

5 files changed

+54
-37
lines changed

5 files changed

+54
-37
lines changed

src/arithmetic/algebraic_expression/algebraic_expression_optimization.c

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -112,48 +112,45 @@ static bool __AlgebraicExpression_MulOverAdd(AlgebraicExpression **root) {
112112
else if((_AlgebraicExpression_IsAdditionNode(l) && !_AlgebraicExpression_IsAdditionNode(r)) ||
113113
(_AlgebraicExpression_IsAdditionNode(r) && !_AlgebraicExpression_IsAdditionNode(l))) {
114114

115-
// Disconnect left and right children from root.
115+
// disconnect left and right children from root
116116
r = _AlgebraicExpression_OperationRemoveDest((*root));
117117
l = _AlgebraicExpression_OperationRemoveDest((*root));
118118
ASSERT(AlgebraicExpression_ChildCount(*root) == 0);
119119

120-
AlgebraicExpression *add = AlgebraicExpression_NewOperation(AL_EXP_ADD);
121-
AlgebraicExpression *lMul = AlgebraicExpression_NewOperation(AL_EXP_MUL);
122-
AlgebraicExpression *rMul = AlgebraicExpression_NewOperation(AL_EXP_MUL);
123-
124-
AlgebraicExpression_AddChild(add, lMul);
125-
AlgebraicExpression_AddChild(add, rMul);
126-
127120
AlgebraicExpression *A;
128121
AlgebraicExpression *B;
129-
AlgebraicExpression *C;
122+
AlgebraicExpression *add = AlgebraicExpression_NewOperation(AL_EXP_ADD);
130123

131124
if(_AlgebraicExpression_IsAdditionNode(l)) {
132-
// Lefthand side is addition.
133-
// (A + B) * C = (A * C) + (B * C)
134-
135-
A = _AlgebraicExpression_OperationRemoveSource(l);
136-
B = _AlgebraicExpression_OperationRemoveDest(l);
137-
C = r;
138-
125+
// lefthand side is addition
126+
// (A + B + C) * D = (A * D) + (B * D) + (C * D)
127+
B = r;
128+
uint child_count = AlgebraicExpression_ChildCount(l);
129+
for(uint i = 0; i < child_count; i++) {
130+
A = _AlgebraicExpression_OperationRemoveDest(l);
131+
AlgebraicExpression *mul = AlgebraicExpression_NewOperation(AL_EXP_MUL);
132+
AlgebraicExpression_AddChild(mul, A);
133+
if(i == 0) AlgebraicExpression_AddChild(mul, B);
134+
else AlgebraicExpression_AddChild(mul, AlgebraicExpression_Clone(B));
135+
AlgebraicExpression_AddChild(add, mul);
136+
}
137+
ASSERT(AlgebraicExpression_ChildCount(l) == 0);
139138
AlgebraicExpression_Free(l);
140-
AlgebraicExpression_AddChild(lMul, A);
141-
AlgebraicExpression_AddChild(lMul, C);
142-
AlgebraicExpression_AddChild(rMul, B);
143-
AlgebraicExpression_AddChild(rMul, AlgebraicExpression_Clone(C));
144139
} else {
145-
// Righthand side is addition.
146-
// C * (A + B) = (C * A) + (C * B)
147-
148-
A = _AlgebraicExpression_OperationRemoveSource(r);
149-
B = _AlgebraicExpression_OperationRemoveDest(r);
150-
C = l;
151-
140+
// righthand side is addition
141+
// D * (A + B + C) = (D * A) + (D * B) + (D * C)
142+
A = l;
143+
uint child_count = AlgebraicExpression_ChildCount(r);
144+
for(uint i = 0; i < child_count; i++) {
145+
B = _AlgebraicExpression_OperationRemoveDest(r);
146+
AlgebraicExpression *mul = AlgebraicExpression_NewOperation(AL_EXP_MUL);
147+
if(i == 0) AlgebraicExpression_AddChild(mul, A);
148+
else AlgebraicExpression_AddChild(mul, AlgebraicExpression_Clone(A));
149+
AlgebraicExpression_AddChild(mul, B);
150+
AlgebraicExpression_AddChild(add, mul);
151+
}
152+
ASSERT(AlgebraicExpression_ChildCount(r) == 0);
152153
AlgebraicExpression_Free(r);
153-
AlgebraicExpression_AddChild(lMul, C);
154-
AlgebraicExpression_AddChild(lMul, A);
155-
AlgebraicExpression_AddChild(rMul, AlgebraicExpression_Clone(C));
156-
AlgebraicExpression_AddChild(rMul, B);
157154
}
158155
// Free original root and overwrite it with new addition root.
159156
AlgebraicExpression_Free(*root);
@@ -163,7 +160,7 @@ static bool __AlgebraicExpression_MulOverAdd(AlgebraicExpression **root) {
163160
}
164161
}
165162

166-
// Recurse.
163+
// recurse
167164
uint child_count = AlgebraicExpression_ChildCount(*root);
168165
for(uint i = 0; i < child_count; i++) {
169166
if(__AlgebraicExpression_MulOverAdd((*root)->operation.children + i)) return true;

tests/flow/test_bidirectional_traversals.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,3 +352,4 @@ def test13_multiple_bidirectional_edges(self):
352352
['v3', 'v1'],
353353
['v3', 'v3']]
354354
self.env.assertEquals(actual_result.result_set, expected_result)
355+

tests/flow/test_bulk_insertion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def test09_large_bulk_insert(self):
376376
redis_con.ping()
377377
t1 = time.time() - t0
378378
# Verify that pinging the server takes less than 1 second during bulk insertion
379-
self.env.assertLess(t1, 1)
379+
self.env.assertLess(t1, 2)
380380
ping_count += 1
381381

382382
thread.join()

tests/flow/test_relation_patterns.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77

88
from base import FlowTestsBase
99

10-
redis_graph = None
11-
1210
GRAPH_ID = "G"
11+
redis_con = None
12+
redis_graph = None
1313

1414
class testRelationPattern(FlowTestsBase):
1515
def __init__(self):
1616
self.env = Env(decodeResponses=True)
17+
global redis_con
1718
global redis_graph
1819
redis_con = self.env.getConnection()
1920
redis_graph = Graph(GRAPH_ID, redis_con)
@@ -287,3 +288,21 @@ def test09_transposed_elem_order(self):
287288
actual_result = g.query(query)
288289
self.env.assertEquals(actual_result.result_set, expected_result)
289290

291+
def test10_triple_edge_type(self):
292+
# Construct a simple graph:
293+
# (A)-[X]->(B)
294+
# (A)-[Y]->(C)
295+
# (A)-[Z]->(D)
296+
g = Graph("triple_edge_type", redis_con)
297+
q = "CREATE(a:A), (b:B), (c:C), (d:D), (a)-[:X]->(b), (a)-[:Y]->(c), (a)-[:Z]->(d)"
298+
g.query(q)
299+
300+
labels = ['X', 'Y', 'Z']
301+
expected_result = [['B'], ['C'], ['D']]
302+
303+
q = "MATCH (a)-[:{L0}|:{L1}|:{L2}]->(b) RETURN labels(b) AS label ORDER BY label"
304+
import itertools
305+
for perm in itertools.permutations(labels):
306+
res = g.query(q.format(L0=perm[0], L1=perm[1], L2=perm[2]))
307+
self.env.assertEquals(res.result_set, expected_result)
308+

tests/unit/test_algebraic_expression.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -841,13 +841,13 @@ TEST_F(AlgebraicExpressionTest, ExpTransform_A_Times_B_Plus_C) {
841841
ASSERT_TRUE(leftLeft->type == AL_OPERAND && leftLeft->operand.matrix == A);
842842

843843
AlgebraicExpression *leftRight = rootLeftChild->operation.children[1];
844-
ASSERT_TRUE(leftRight->type == AL_OPERAND && leftRight->operand.matrix == B);
844+
ASSERT_TRUE(leftRight->type == AL_OPERAND && leftRight->operand.matrix == C);
845845

846846
AlgebraicExpression *rightLeft = rootRightChild->operation.children[0];
847847
ASSERT_TRUE(rightLeft->type == AL_OPERAND && rightLeft->operand.matrix == A);
848848

849849
AlgebraicExpression *rightRight = rootRightChild->operation.children[1];
850-
ASSERT_TRUE(rightRight->type == AL_OPERAND && rightRight->operand.matrix == C);
850+
ASSERT_TRUE(rightRight->type == AL_OPERAND && rightRight->operand.matrix == B);
851851

852852
raxFree(matrices);
853853
GrB_Matrix_free(&A);

0 commit comments

Comments
 (0)