Monday, May 19, 2014

Combining array bounds checking and loop condition checking into a single operation

In a memory-safe language, the naive way to loop over an array involves too many runtime checks:

int i = 0;
while (i < array.length) {
    doSomething(array[i]);
    i = i + 1;
}

The variable i is checked against array.length in the loop condition, and then checked against 0 and array.length again when array[i] is dereferenced. That's three comparisons for every element of the array. A clever compiler can remove some of them, but that requires complicated mathematical reasoning. For example, to remove the check against 0, the compiler must build a proof by mathematical induction that i always stays non-negative.
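To make the count concrete, here is a hedged C++ sketch of roughly what a bounds-checking language does in the naive loop before any optimization. The helper names `checkedGet` and `naiveSum` are hypothetical, `doSomething()` is replaced by a running sum, and the checks are written out by hand. Real compilers for memory-safe languages may lay them out differently.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper: array[i] with the two bounds checks a memory-safe
// language performs on every access, written out by hand.
double checkedGet(const std::vector<double>& array, long i) {
    if (i < 0) {                                        // check 2: i >= 0
        throw std::out_of_range("index is negative");
    }
    if (i >= static_cast<long>(array.size())) {         // check 3: i < length
        throw std::out_of_range("index is past the end");
    }
    return array[static_cast<std::size_t>(i)];
}

// The naive loop, with doSomething() replaced by a sum.
double naiveSum(const std::vector<double>& array) {
    double sum = 0.0;
    long i = 0;
    while (i < static_cast<long>(array.size())) {       // check 1: loop condition
        sum += checkedGet(array, i);
        i = i + 1;
    }
    return sum;
}
```

Checks 1 and 3 compare the same two values, and check 2 can only be dropped once the compiler proves that `i` never goes negative.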

In this post I'll try to make it easier for the compiler to remove all checks except one. First, let's note that looping over a linked list, implemented as a sum type, requires only one check per element to be completely memory-safe:

import Prelude hiding (product)  -- avoid clashing with the Prelude's product

data List a = Empty | Append a (List a)

product :: List Int -> Int
product Empty = 1
product (Append value tail) = value * product tail

We can see that looping over a linked list requires only one check per element, which serves a double purpose: it is both a bounds check (for memory safety) and a loop condition check (for loop termination). So why does looping over an array require three checks? Can we reduce the number of checks without losing memory safety?
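For comparison in the same language as the array code, here is a rough C++ analogue of the Haskell `product`, offered as an illustration only. A null pointer plays the role of the `Empty` constructor, so the single `node != nullptr` test both ends the loop and guards the dereference. The `Node` struct and `listProduct` function are hypothetical names, and unlike Haskell's pattern matching, C++ does not force anyone to perform that test.

```cpp
// Hypothetical singly linked list: nullptr stands for Empty,
// and a Node stands for Append value tail.
struct Node {
    int value;
    const Node* tail;
};

// One check per element: the nullptr test is both the loop condition
// and the only thing that makes dereferencing `node` safe.
int listProduct(const Node* node) {
    int result = 1;
    while (node != nullptr) {
        result *= node->value;
        node = node->tail;
    }
    return result;
}
```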

The answer is that arrays can get by with one check per element too, if we treat them like lists. We just need an iterator class that is guaranteed to hold a valid index into the array, so that dereferencing the iterator requires no bounds check, and moving it forward or backward requires exactly one:

#include <iostream>

template <typename Type> class NonEmptyArray;  // defined below

template <typename Type> class NonEmptyIterator {
    friend class NonEmptyArray<Type>;

    private:
    Type* array;
    unsigned int maxIndex;
    unsigned int index;  // 0 <= index <= maxIndex
    NonEmptyIterator(Type* array_, unsigned int maxIndex_):
        array(array_), maxIndex(maxIndex_), index(0) {}

    public:
    Type get() {
        return array[index];
    }
    
    void set(Type value) {
        array[index] = value;
    }

    bool increment() {
        if (index < maxIndex) {
            index++;
            return true;
        } else {
            return false;
        }
    } 

    bool decrement() {
        if (index > 0) {
            index--;
            return true;
        } else {
            return false;
        }
    } 
};

template <typename Type> class NonEmptyArray {
    private:  
    Type* array;
    unsigned int maxIndex;  // this means length - 1

    public:
    NonEmptyArray(unsigned int maxIndex_): maxIndex(maxIndex_) {
        array = new Type[maxIndex + 1];
    }
    
    ~NonEmptyArray() {
        delete [] array;
    }

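    // Noncopyable: a copy would delete[] the same buffer twice.
    NonEmptyArray(const NonEmptyArray&) = delete;
    NonEmptyArray& operator=(const NonEmptyArray&) = delete;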

    NonEmptyIterator<Type> iterator() {
        return NonEmptyIterator<Type>(array, maxIndex);
    }
};

int main() {
    NonEmptyArray<double> array(2);  // that means length is 3
    NonEmptyIterator<double> it(array.iterator());

    // Fill the array
    double x = 1.0;
    do {
        it.set(x);
        x += 1.0;
    } while (it.increment());

    // Compute the sum of the array elements
    double sum = 0.0;
    do {
        sum += it.get();
    } while (it.decrement());
    std::cout << sum << std::endl;  // prints 6
}

The above iterator class works only for non-empty arrays, because it needs to maintain the invariant that get() always returns a valid value. We could work around that by simulating sum types in C++, but I don't want to overcomplicate the code. The important thing is that once the iterator is constructed, calling the methods get(), set(), increment() and decrement() in any order is completely memory-safe for the lifetime of the array, while requiring only one bounds check for the typical case of traversing the array in a loop.
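One way to lift the non-empty restriction is sketched below, assuming C++17: `std::optional` stands in for the sum type, so asking a possibly empty array for an iterator yields either nothing or an iterator at a valid index. The class names `MaybeEmptyArray` and `ArrayIterator` are hypothetical, and the sketch swaps the raw `new[]`/`delete[]` for `std::unique_ptr`, which also makes the array noncopyable without extra boilerplate.

```cpp
#include <cstddef>
#include <memory>
#include <optional>

template <typename Type> class MaybeEmptyArray;

// Same idea as NonEmptyIterator: it can only be created pointing at a
// valid element, so get() and set() never need a bounds check.
template <typename Type> class ArrayIterator {
    friend class MaybeEmptyArray<Type>;
    Type* array;
    std::size_t maxIndex;
    std::size_t index;  // invariant: 0 <= index <= maxIndex
    ArrayIterator(Type* array_, std::size_t maxIndex_)
        : array(array_), maxIndex(maxIndex_), index(0) {}

public:
    Type get() const { return array[index]; }
    void set(Type value) { array[index] = value; }
    bool increment() {  // the one check when moving forward
        if (index < maxIndex) { ++index; return true; }
        return false;
    }
    bool decrement() {  // the one check when moving backward
        if (index > 0) { --index; return true; }
        return false;
    }
};

// A possibly empty array. std::optional plays the role of the sum type:
// the caller gets either "no iterator" or an iterator at index 0.
template <typename Type> class MaybeEmptyArray {
    std::unique_ptr<Type[]> array;
    std::size_t length;

public:
    explicit MaybeEmptyArray(std::size_t length_)
        : array(new Type[length_]()), length(length_) {}

    std::optional<ArrayIterator<Type>> iterator() {
        if (length == 0) return std::nullopt;  // emptiness checked once, here
        return ArrayIterator<Type>(array.get(), length - 1);
    }
};
```

The emptiness check is paid once, when `iterator()` is called; the traversal loop itself still performs one check per element.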

(Clever readers might point out that the check inside increment() and decrement() returns a boolean, which is then inspected in the loop condition, so that's technically two checks. But if the method calls are inlined, the compiler can trivially combine these two checks into one without doing any fancy mathematical reasoning.)
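The inlining argument can be illustrated by hand. The two functions below are hypothetical hand-written versions of the summing loop from `main()`, not actual compiler output, and whether a given compiler performs this simplification depends on its optimizer. The first pastes `decrement()` in verbatim; the second shows what remains once the boolean's test is folded into the branch that set it.

```cpp
#include <cstddef>

// The do/while loop with decrement() pasted in verbatim:
// two tests per element, the comparison and the test of its boolean result.
double sumTwoChecks(const double* array, std::size_t maxIndex) {
    double sum = 0.0;
    std::size_t index = maxIndex;  // iterator starts at the last element
    bool more;
    do {
        sum += array[index];       // get(): no check, index is valid
        if (index > 0) { index--; more = true; } else { more = false; }
    } while (more);
    return sum;
}

// What it can reduce to without any induction proof: the value of `more`
// is known in each branch, so its test folds away.
double sumOneCheck(const double* array, std::size_t maxIndex) {
    double sum = 0.0;
    std::size_t index = maxIndex;
    for (;;) {
        sum += array[index];
        if (index == 0) break;     // the single remaining check
        index--;
    }
    return sum;
}
```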

It's an interesting exercise to adapt the above approach to more complicated cases, like traversing two arrays of the same length at once. But I feel that's not really the point. Unlike approaches based on theorem provers or dependent types (or even "non-dependent types"), this approach to eliminating bounds checking probably doesn't scale to more complicated use cases. But it does have the advantage of being simple and understandable to most programmers, without using any type-theoretic tricks.
