Skip to content

Commit f48364e

Browse files
authored
interp: increase speed of findSegment
Using slices.BinarySearch instead of sort.Search increases the speed of findSegment by a factor of two and overall performance by about 30%. goos: linux goarch: amd64 pkg: gonum.org/v1/gonum/interp cpu: AMD Ryzen 7 5800 8-Core Processor │ old.bench │ new.bench │ │ sec/op │ sec/op vs base │ FindSegment-16 104.60n ± 1% 50.78n ± 1% -51.45% (p=0.000 n=10) NewPiecewiseLinear-16 114.5n ± 5% 112.2n ± 2% ~ (p=0.109 n=10) PiecewiseLinearPredict-16 116.00n ± 1% 84.44n ± 2% -27.21% (p=0.000 n=10) PiecewiseConstantPredict-16 87.95n ± 2% 63.93n ± 1% -27.31% (p=0.000 n=10) geomean 105.2n 74.47n -29.18%
1 parent 1dd194f commit f48364e

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

interp/interp.go

+7-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
package interp
66

7-
import "sort"
7+
import "slices"
88

99
const (
1010
differentLengths = "interp: input slices have different lengths"
@@ -156,10 +156,13 @@ func (pc PiecewiseConstant) Predict(x float64) float64 {
156156
}
157157

158158
// findSegment returns 0 <= i < len(xs) such that xs[i] <= x < xs[i + 1], where xs[len(xs)]
159-
// is assumed to be +Inf. If no such i is found, it returns -1. It assumes that len(xs) >= 2
160-
// without checking.
159+
// is assumed to be +Inf. If no such i is found, it returns -1.
161160
func findSegment(xs []float64, x float64) int {
162-
return sort.Search(len(xs), func(i int) bool { return xs[i] > x }) - 1
161+
i, found := slices.BinarySearch(xs, x)
162+
if !found {
163+
return i - 1
164+
}
165+
return i
163166
}
164167

165168
// calculateSlopes calculates slopes (ys[i+1] - ys[i]) / (xs[i+1] - xs[i]).

interp/interp_test.go

+22
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,28 @@ func TestFindSegment(t *testing.T) {
5050
}
5151
}
5252

53+
func TestFindSegmentEdgeCases(t *testing.T) {
54+
t.Parallel()
55+
56+
cases := []struct {
57+
xs []float64
58+
x float64
59+
want int
60+
}{
61+
{xs: nil, x: 0, want: -1},
62+
{xs: []float64{0}, x: -1, want: -1},
63+
{xs: []float64{0}, x: 0, want: 0},
64+
{xs: []float64{0}, x: 1, want: 0},
65+
}
66+
67+
for _, test := range cases {
68+
if got := findSegment(test.xs, test.x); got != test.want {
69+
t.Errorf("unexpected value of findSegment(%v, %f): got %d want: %d",
70+
test.xs, test.x, got, test.want)
71+
}
72+
}
73+
}
74+
5375
func BenchmarkFindSegment(b *testing.B) {
5476
xs := []float64{0, 1.5, 3, 4.5, 6, 7.5, 9, 12, 13.5, 16.5}
5577
for i := 0; i < b.N; i++ {

0 commit comments

Comments
 (0)