Alex Rodrigues

Data science, distributed systems and big data

Computing Approximate Histograms in Parallel

Today I’m going to write a little about Approximate Histograms and how can they be used to get more insight on streamed big data feeds. I also provide a simple Java implementation and explain some parts of it.

Most of the common aggregation operations like counting and summing can be performed in parallel, as long there is a reduce phase where the result on each node can be combined. However, this is not very trivial for calculating histograms, as we need all the data on one dimension so that we can represent it in an histogram.

Having the data being processed by multiple nodes, each node is only able to construct an histogram of the partial data it receives. Ben-Haim and Tom-Tov presented a solution that uses an heap-based data structure to represent the data and a merge algorithm that allows to merge the data structures computed on different nodes into one that is an approximate histogram of all the dataset.

This technique has been applied by MetaMarkets with good accuracy for most of what an histogram can tell us about the data distribution: calculating the average and counting the quartiles and total number of data/events.

I took the liberty of doing a simple implementation of it, that is now being used in production for some months now:

ApproximateHistogram.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.Iterator;
import java.util.Set;
import java.util.TreeSet;

/**
 * This is an approximate histogram class
 */
public class ApproximateHistogram {
    private final int numPairs;
    private final TreeSet<CentroidPair> heap;

    public ApproximateHistogram(int numPairs) {
        this.numPairs = numPairs;
        this.heap = new TreeSet();
    }

    public ApproximateHistogram(TreeSet<CentroidPair> heap, int numPairs) {
        this.numPairs = numPairs;
        this.heap = heap;
    }

    public Set<CentroidPair> heap() {
        return ImmutableSet.copyOf(heap);
    }

    public void update(CentroidPair p) {
        Iterator<CentroidPair> it = heap.iterator();
        while (it.hasNext()) {
            CentroidPair cp = it.next();

            int compare = Double.compare(cp.centroid, p.centroid);
            if (compare == 0) {
                cp.count += p.count;
                return;
            } else if (compare == 1) {
                break;
            }
        }

        // there was no similar centroid, so let's add the point to the heap
        heap.add(p);

        compress();
    }

    private void compress() {
        if (heap.size() <= numPairs) {
            return; // compress only if needed
        }

        int i = 0;
        double minDiff = Double.MAX_VALUE;
        CentroidPair last = null, lastLast = null;

        // [ ..., minA, minB, ... ] two consecutive pairs which centroid diff is the minimum
        CentroidPair minA = null, minB = null;

        Iterator<CentroidPair> it = heap.iterator();
        while (it.hasNext()) {
            lastLast = last;
            last = it.next();

            if (i > 0) {
                double diff = last.centroid - lastLast.centroid;

                if (diff < minDiff) {
                    minA = lastLast;
                    minB = last;
                    minDiff = diff;
                }
            }
            ++i;
        }

        int repCount = Math.abs(minA.count) + Math.abs(minB.count);
        double repCentroid = (minA.centroid * Math.abs(minA.count) + minB.centroid * Math.abs(minB.count)) / repCount;
        CentroidPair replacementPair = new CentroidPair(-repCount, repCentroid); // store with negative sign the compressed entries
        heap.remove(minA);
        heap.remove(minB);
        heap.add(replacementPair);
    }


    public static ApproximateHistogram merge(ApproximateHistogram... histograms) {
        ApproximateHistogram merged = histograms[0];

        for (int i = 1; i < histograms.length; i++) {
            merged = merge(merged, histograms[i]);
        }

        return merged;
    }

    public static ApproximateHistogram merge(ApproximateHistogram a, ApproximateHistogram b) {
        int biggestSize = a.heap.size();

        if (b.heap.size() > biggestSize) {
            biggestSize = b.heap.size();
        }

        TreeSet<CentroidPair> mergedHeap = Sets.newTreeSet();
        mergedHeap.addAll(a.heap);
        mergedHeap.addAll(b.heap);

        final ApproximateHistogram merged = new ApproximateHistogram(mergedHeap, biggestSize);

        // add the centroids of B to the merged (ignoring compression)
        int compressTimes = mergedHeap.size() - biggestSize;

        while (compressTimes-- > 0) {
            merged.compress();
        }

        return merged;
    }


    public double countBelow(double cutPoint) {
        final double EPSILON = 0.00000001;
        if (heap.isEmpty()) return 0.0;


        CentroidPair[] heapPoints = heap.toArray(new CentroidPair[heap.size()]);

        int j = 0;
        for (int i = 0; i < heapPoints.length; i++) {
            int count = heapPoints[i].count;
            double diff = heapPoints[i].centroid - cutPoint;

            // there's a pair with the cutPoint as centroid
            if (Math.abs(diff) < EPSILON) {
                return j + ((count > 0) ? count : Math.abs(count) / 2.0);
            } else if (diff > 0) {
                // we already passed. it's somewhere between the last and this one

                // CASE: the cutPoint is before the first centroid point
                if (i == 0) {
                    if (count > 0) return 0.0; // we are sure no entry was less than the first centroid
                    return Math.abs(count) * cutPoint / (2.0 * heapPoints[i].centroid);// the first pair is an average. do the calculation
                }

                CentroidPair lastPoint = heapPoints[i - 1];
                CentroidPair currentPoint = heapPoints[i];
                int lastCount = lastPoint.count;

                // if the last point is just an average point, discount it
                j -= ((lastCount < 0) ? Math.abs(lastCount) : 0); // WHT?

                lastCount = Math.abs(lastCount);

                double mb = lastCount + (Math.abs(currentPoint.count) - lastCount) * (cutPoint - lastPoint.centroid) / (currentPoint.centroid - lastPoint.centroid);
                double sum = (lastCount + mb) * (cutPoint - lastPoint.centroid) / (2.0 * (currentPoint.centroid - lastPoint.centroid));

                return sum + j + Math.abs(lastCount) / 2.0;
            }

            j += Math.abs(count);
        }

        // some logic for the cases where b > centroid[last]
        CentroidPair lastPoint = heapPoints[heapPoints.length - 1];
        int count = lastPoint.count;

        // last point is an average and there's more than one
        if (count < 0 && heapPoints.length > 1) {
            count = Math.abs(count);
            CentroidPair lastLastPoint = heapPoints[heapPoints.length - 2];

            // calculate a virtual final point which is separated half the distance than the last one
            double distanceToPreviousPoint = lastPoint.centroid - lastLastPoint.centroid;
            distanceToPreviousPoint /= 4.0;
            double finalCentroid = lastPoint.centroid + distanceToPreviousPoint;

            double diff = finalCentroid - cutPoint;

            // count all!
            if (diff > 0) {
                j -= count / 2.0;

                double trapezoidSum = count * (cutPoint - lastPoint.centroid) / (2.0 * distanceToPreviousPoint);

                return j + trapezoidSum;
            }
            // else return j!
        }


        return j;
    }

    public double avg() {
        int count = 0;
        double sum = 0.0;

        for (CentroidPair centroidPair : heap) {
            int absCount = Math.abs(centroidPair.count);
            count += absCount;
            sum += absCount * centroidPair.centroid;
        }

        return (count > 0) ? sum / count : 0.0;
    }

    public int count() {
        int sum = 0;
        for (CentroidPair centroidPair : heap) {
            sum += centroidPair.count;
        }
        return sum;
    }

    public static class CentroidPair implements Comparable<CentroidPair> {
        int count;
        double centroid;

        public CentroidPair(int count, double centroid) {
            this.count = count;
            this.centroid = centroid;
        }

        @Override
        public int compareTo(CentroidPair o) {
            return Double.compare(this.centroid, o.centroid);
        }

        @Override
        public String toString() {
            return new StringBuilder("(").append(count).append(", ").append(centroid).append(")").toString();
        }
    }
}

Internally the histogram is represented by a set of points (count, centroid), ordered by it’s centroid. When a new point is added, if the centroid already exists we increase the count number, otherwise we add the point with count 1 to the list.

Each histogram has a limit of points to keep and when a new insert exceeds this limit, a compression takes place. The compression consists in merging the two consecutive points where the difference between its centroids is the lower. The two are replaced by a single point with centroid on a place nearer to the neighbor point that has more counts: if they have the same count, it would be on the middle. The count of the new point will be the sum of the two old ones.

As Java doesn’t have unsigned numeric types, this implementation exploits the signal in the count field to flag if that point has been originated from compression of two other or if it is from raw observations. This can help answering to questions like: how many values are below X? If the points have a positive count for every point whose centroid is below X, we can truly count them. If they are negative, we know that point is an approximation, so we calculate the count using the trapezoidal estimation of Ben-Haim and Tom-Tov. This gives more accurate results than assuming every point might be an approximation and requires no extra space in Java-based data structures.

For merging more than one histogram, which happens when we want to combine results computed on different nodes. This is done by creating a big heap with the combined values of the histograms and applying compression on that heap, as described above, until the heap has the maximum number of points.

For very disperse data, this data structure may yield bad approximations if the number of points is not high enough. This data structure is very flexible and it’s easy to use it for streams with different distributions by just tuning the number of centroids we keep.

Comments