Conversation

@jakevdp (Member) commented Aug 22, 2011

This is a complete re-write of BallTree, as recently discussed on the mailing list.

Main advantages:

  • All ball tree information is stored in numpy arrays, rather than in
    dynamically allocated C arrays. This allows a constructed
    BallTree to be pickled and unpickled without the need to rebuild the tree.
  • Multiple distance metrics are supported. Currently, it can
    efficiently use any Minkowski p-distance (similar to
    scipy.spatial.cKDTree); a quick usage sketch follows this list. Because
    the distance functions are written in Cython, support for other metrics
    can be added more easily.
  • Compared to the current C++ BallTree implementation, the new code is
    faster by a factor of 5-8 for building the tree, and about 30-50% faster
    for querying the tree, depending on the type of input data.
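
A quick usage sketch of the Minkowski p-distance support described above (the leaf_size and p keyword names here are assumed from this description rather than taken from the final API; the import path follows the examples later in this thread):

import numpy as np
from scikits.learn.neighbors import BallTree

X = np.random.random((1000, 5))

# Euclidean metric (Minkowski p=2) and Manhattan metric (Minkowski p=1);
# leaf_size and p are assumed keyword names, see the note above
bt_euclidean = BallTree(X, leaf_size=20, p=2)
bt_manhattan = BallTree(X, leaf_size=20, p=1)

# distances and indices of the 5 nearest neighbors of the first 3 points
dist, ind = bt_euclidean.query(X[:3], k=5)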

Disadvantages:

  • The code is not as easy to understand as the C++ implementation. A
    class-based approach is much more intuitive. Pseudo-code for a class-based
    implementation is included in the implementation notes.
  • The code mostly uses raw pointers, because numpy array slicing led to
    too much call overhead. This makes the code more difficult to modify,
    and I couldn't use numpy routines for things like searching and sorting
    arrays - I wrote fast Cython routines for a few of these basic algorithms.
    Because they're purpose-built for this implementation, they perform well.
  • Allocating all memory before building the tree leads to less
    flexibility. The old C++ implementation would easily allow extensions
    to different construction methods, inline data addition and subtraction, etc. Adding
    these sorts of extensions to the new module would be much more difficult.
  • Memory allocation is not exact (fixed in commit below)

As noted on the mailing list, I think the advantages far outweigh the disadvantages. Currently all tests pass: I think it's ready to go, barring any further input. For quick comparison of execution times between this implementation and the old implementation, see http://github.com/jakevdp/pyTree

@ogrisel (Member) commented Aug 22, 2011

After a quick overview, this looks really good. I like the NodeInfo cast that makes the code more readable :) About the memory allocation: have you done any checks on various dataset sizes and structures to verify that it does not behave too badly in practice?

@jakevdp (Member, Author) commented Aug 22, 2011

I have checked the memory issue: see calc_tree_size.py in http://github.com/jakevdp/pyTree
It plots the number of nodes, the required size of the array, and the size of the allocated memory as a function of n_features and leaf_size.
The over-allocation is a small percentage of the total, except for very specific values of n_features with leaf_size near 1. It's a rare problem, but it might warrant more thought.

@jakevdp (Member, Author) commented Aug 23, 2011

I was thinking about this while on my bicycle (it's when I get my best thinking done) and had a flash of insight: if we adjust the definition of leaf_size, then the memory allocation becomes much simpler. Previously, I had called a node a leaf node whenever n_points <= leaf_size. This can often lead to partially filled levels and slightly unbalanced trees. If instead, we just guarantee that leaf_size <= n_points <= 2 * leaf_size for any leaf node, the tree can be balanced and the total size is relatively easy to compute. I've implemented this in the most recent commit: details can be found in the doc in ball_tree.pyx.
The net result is that now there is no wasted memory. I tested it on training sets with 1 <= n_samples <= 2000 and 1 <= leaf_size <= 20 and it's working as expected.
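
To illustrate the arithmetic, here is a small sketch with illustrative names (not the actual code from ball_tree.pyx): once every leaf is guaranteed to hold between leaf_size and 2 * leaf_size points, the tree is a complete binary tree whose depth, and therefore total node count, follows directly from n_samples and leaf_size, so the node arrays can be allocated exactly up front.

import numpy as np

def tree_layout(n_samples, leaf_size):
    # illustrative sketch: depth of a balanced tree whose leaves hold
    # between leaf_size and 2 * leaf_size points, and the exact number
    # of nodes to pre-allocate for it
    n_levels = 1 + int(np.log2(max(1, (n_samples - 1.0) / leaf_size)))
    n_nodes = 2 ** n_levels - 1
    return n_levels, n_nodes

# e.g. 2000 points with leaf_size = 20 -> 7 levels, 127 nodes:
# 64 leaves of ~31 points each, comfortably within [20, 40]
print(tree_layout(2000, 20))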

@fabianp (Member) commented Aug 23, 2011

For the common routines (copy, sorting, searching), it is also possible to use the NumPy C API [0], thus avoiding the Python overhead. I'm thinking of PyArray_ArgSort, PyArray_Max, PyArray_Copy, etc.

[0] http://docs.scipy.org/doc/numpy/reference/c-api.array.html

@jakevdp (Member, Author) commented Aug 23, 2011

Interesting thought, but I'm not sure it would lead to faster code. In the case of sorting distances and indices, for example, it seems faster to use the current simultaneous in-place quicksort of the two arrays, rather than using PyArray_ArgSort to create an array of sort indices, using those indices to construct sorted copies of the two arrays, and then copying the results back into the original arrays. Regarding searching, copying, etc., there may be some gain from using the more optimized routines in the NumPy C-API, but you wouldn't get around the overhead of creating and destroying a PyArrayObject for each operation.
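
For illustration only, here is a pure-Python sketch of the simultaneous-sort idea (the Cython version in this PR works on raw pointers; the function names here are made up): one partitioning pass permutes both arrays together, so no temporary argsort array or sorted copies are needed.

def _dual_swap(dist, idx, i, j):
    # swap positions i and j in both arrays at once
    dist[i], dist[j] = dist[j], dist[i]
    idx[i], idx[j] = idx[j], idx[i]

def simultaneous_sort(dist, idx, lo=0, hi=None):
    # in-place quicksort of dist[lo:hi], carrying idx along with it
    if hi is None:
        hi = len(dist)
    if hi - lo < 2:
        return
    pivot = dist[hi - 1]
    store = lo
    for i in range(lo, hi - 1):
        if dist[i] < pivot:
            _dual_swap(dist, idx, i, store)
            store += 1
    _dual_swap(dist, idx, store, hi - 1)
    simultaneous_sort(dist, idx, lo, store)
    simultaneous_sort(dist, idx, store + 1, hi)

dist = [0.9, 0.1, 0.5, 0.3]
idx = [3, 0, 2, 1]
simultaneous_sort(dist, idx)   # dist -> [0.1, 0.3, 0.5, 0.9], idx -> [0, 1, 2, 3]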

@ogrisel (Member) commented Aug 23, 2011

I think @fabianp wanted to emphasize code reuse (for easier maintenance) rather than performance. Indeed, using the NumPy C-API might add a bit of overhead compared with your implementation, so I am OK with not reusing it in this case if you think it is not worth it.

@jakevdp (Member, Author) commented Aug 23, 2011

Ah, I see. I would also put more stock in speed than in code reuse in this situation. Fabian's suggestion leads to an interesting thought, though: the main reason I switched from numpy arrays to raw array pointers was the overhead of numpy array slicing. Using the numpy C-API directly to create array views/slices more quickly, without the Python overhead, would have been another possible route. At this point, though, unless there's a very compelling reason, I'm not sure it's worth investing in a full rewrite of the module.

@fabianp (Member) commented Aug 24, 2011

Indeed, at this point it would be too much work and the code is good enough.

I checked performance on a synthetic dataset and got impressive improvements:

With this version:

(p26)~/dev/scikit-learn(master) python benchmarks/bench_balltree.py 5000 10
---------------------------------------------------
20 neighbors of 5000 points in 10 dimensions:
   (leaf size = 20)
  -------------
  Ball Tree construction     : 0.0061 sec
  total (construction+query) : 1.07 sec
  -------------
  KD tree construction       : 0.00142 sec
  total (construction+query) : 4.18 sec
  -------------
  neighbors match:  True
  -------------

With the old one:

(p26)~/dev/scikit-learn(master) python benchmarks/bench_balltree.py 5000 10
---------------------------------------------------
20 neighbors of 5000 points in 10 dimensions:
   (leaf size = 20)
  -------------
  Ball Tree construction     : 0.0473 sec
  total (construction+query) : 1.79 sec
  -------------
  KD tree construction       : 0.00149 sec
  total (construction+query) : 4.09 sec
  -------------
  neighbors match:  True
  -------------

fabianp pushed a commit that referenced this pull request Aug 24, 2011
fabianp merged commit 12e1c51 into scikit-learn:master Aug 24, 2011

@fabianp (Member) commented Aug 24, 2011

Merged

@ogrisel (Member) commented Aug 24, 2011

\o/

@fabianp (Member) commented Aug 24, 2011

However, I still cannot pickle the BallTree object; here is a test case:

https://github.com/scikit-learn/scikit-learn/compare/bt_pickle

@jakevdp (Member, Author) commented Aug 24, 2011

It needs protocol=2 to pickle.

@ogrisel (Member) commented Aug 24, 2011

Why so? Is there no way to make it work with the default protocol?

@jakevdp (Member, Author) commented Aug 24, 2011

As far as I can tell, because Cython classes don't expose their __dict__ attribute, the default protocol requires __reduce__, which calls the __init__ method. Since we don't want to re-build the BallTree after pickling, I instead defined __getstate__ and __setstate__, which only work with protocol 2.
We could hack it to work with the default: it would require some sort of keyword in the __init__ method which tells whether the object is being initialized by the user or by the unpickler. If it's by the unpickler, the arguments would be the internal state of the pickled BallTree.
There may be a cleaner way to approach this. Any ideas?
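
For reference, a pure-Python sketch of one possible cleaner route (the class and helper names here are made up; this is not the code from this PR): defining __reduce__ to return a module-level reconstructor plus the __getstate__ tuple avoids calling __init__ on unpickling and works with the default protocol as well as protocol 2.

import pickle
import numpy as np

def _new_instance(cls):
    # module-level helper so pickle can find it by name when unpickling
    return cls.__new__(cls)

class FakeTree(object):
    # stand-in for an extension type whose construction is expensive
    def __init__(self, X, leaf_size=20):
        self.data = np.asarray(X)
        self.leaf_size = leaf_size
        # ... expensive tree construction would happen here ...

    def __getstate__(self):
        return (self.data, self.leaf_size)

    def __setstate__(self, state):
        self.data, self.leaf_size = state

    def __reduce__(self):
        # reconstruct via __new__ + __setstate__, never via __init__
        return (_new_instance, (FakeTree,), self.__getstate__())

tree = FakeTree(np.random.random((10, 3)))
restored = pickle.loads(pickle.dumps(tree))  # works with the default protocol too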

@jakevdp (Member, Author) commented Aug 24, 2011

Also, pickling seems not to work at all when BallTree is imported in certain ways, even with protocol 2! It's strange: the pickling doctest passes, but Gael's test doesn't. I'm not sure what's going on here. I seem to remember reading that pickling can behave strangely at times with nested imports. That may be part of the issue.

@jakevdp (Member, Author) commented Aug 25, 2011

I've got pickling working with all protocols: the code is at http://github.com/jakevdp/pyTree
The problem is, when I move this version of ball_tree.pyx into the scikit-learn directory structure (see https://github.com/jakevdp/scikit-learn/tree/cython-ball-tree), pickle has trouble finding the paths it needs. I've spent a half hour tinkering and trying to figure out why it's not recognizing the namespace, but I have no idea.

@jakevdp (Member, Author) commented Aug 25, 2011

I think I've found the source of the problem:

>>> from scikits.learn.neighbors import NeighborsClassifier, BallTree
>>> print NeighborsClassifier
<class 'scikits.learn.neighbors.NeighborsClassifier'>
>>> print BallTree
<type 'ball_tree.BallTree'>

The import isn't picking up the full path to BallTree, which is what pickle uses to find the class upon unpickling. This leads to the error:

>>> X = np.random.random((10, 3))
>>> balltree = BallTree(X)
>>> s = pickle.dumps(balltree, protocol=2)
...
pickle.PicklingError: Can't pickle <type 'ball_tree.BallTree'>: it's not found as ball_tree.BallTree

This problem didn't come up in the pickling doc-test that I wrote, probably because that test happens at a local level. Any ideas about how to fix this?
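
To make the failure mode concrete, here is roughly what the pickler does when it serializes a class reference (illustrative only, not a proposed fix): it records the class's module name and class name and re-resolves them on load, so a class that reports its module as plain ball_tree can only be handled if ball_tree is importable on its own.

import importlib

def locate_class(module_name, class_name):
    # roughly how pickle re-resolves a class when unpickling
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# With BallTree.__module__ == 'ball_tree' rather than
# 'scikits.learn.ball_tree', this lookup (and hence pickling) fails
# unless the bare 'ball_tree' module happens to be on the path:
# locate_class('ball_tree', 'BallTree')                # ImportError in general
# locate_class('scikits.learn.ball_tree', 'BallTree')  # what we'd want pickle to do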

@ogrisel (Member) commented Aug 26, 2011

Good catch! I wonder if the lack of introspectable path is related to the parent_package and top_path arguments of the build setup:

https://github.com/scikit-learn/scikit-learn/blob/master/scikits/learn/setup.py#L6

@jakevdp (Member, Author) commented Aug 26, 2011

Could be... I'm going to be out camping until the middle of next week. I'll try to track down the problem once I'm back!

@ogrisel (Member) commented Aug 26, 2011

Ok. FYI, I have tried to hack a self.__module__ = 'scikits.learn.ball_tree' into the constructor of the BallTree class, but it seems to be ignored by the pickler, which always looks for the ball_tree module name instead. Might be worth bringing the issue up on the cython-users mailing list.

@fabianp (Member) commented Aug 26, 2011

I opened an issue to keep track of this: #323
