Saturday, May 12, 2012

Even More Concurrency; Or, What You Can't Find By Searching Google Scholar

(Continuing the series of Schrodinger's Equation of Software, First Class Everything: Loops, Tail-Calls, and Continuations, and Automatic Concurrency)

      Our last version of cps_map_eval gave us concurrent evaluation of arguments, and without even needing any locks or other kinds of synchronization. That's pretty impressive... except that it only works as long as a) every argument returns at least once and b) no argument returns more than once before all have returned at least once.
      Those may not seem like very strict restrictions, except that the whole point of continuations is to break them. Consider this program:

((vau () % (+ (return 2) 2)))

      With the last version of cps_map_eval, it will cause an error: when the first argument thread terminates, it will not have filled in its slot, because it jumped to a different continuation. The parent thread will try to evaluate (+ None 2) and throw an exception, when what we really wanted was for it not to run at all! A similar problem can be seen in this program:

(+  ((vau () % (par (return 1) (return 2))))
    (do-long-computation))

      In this case, no arguments fail to return. But the first argument activates its continuation twice, and from different threads. If both of those continuation calls execute before (do-long-computation) completes, the first one will fill in the argument slot and disappear as intended, but the second will go ahead and execute the function body with an incomplete argument list, because we have erroneously assumed that a continuation can only be called a second time after the function body has begun executing. That assumption is false as soon as argument expressions can themselves contain multiple threads, and even more false if you allow mutation, so that external threads might get hold of the continuation.
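      To make that failure concrete outside the interpreter, here is a tiny, purely hypothetical model of the old rule (the first call fills the slot, any later call copies the argument list and runs the body), driven by hand in the order the threads could interleave:

# Hypothetical, stripped-down model of the old scheme; not interpreter code.
argv = [None, None]                       # slots for the two arguments of +
body = lambda args: ("would apply + to", args)

def make_continuation(i):
    first = [True]
    def k(val):
        if first[0]:
            first[0] = False
            argv[i] = val                 # initial return: just fill the slot
        else:
            new_argv = argv[:]            # later return: assume the others are
            new_argv[i] = val             # already filled, and run the body
            return body(new_argv)
    return k

k0 = make_continuation(0)

# (par (return 1) (return 2)) can invoke k0 twice before do-long-computation
# ever returns to fill slot 1:
k0(1)
print(k0(2))   # ('would apply + to', [2, None]) -- the body runs with a hole in it
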
      I searched far and wide for existing research on combining continuations and concurrency, but it seems that on this particular topic, nothing exists. Fortress does not have continuations, and no experimental LISP dialect has automatic concurrency! There is a lot of interesting work on what continuations mean in concurrent systems, and on different ways to deal with them when you have explicit language constructs for parallelism like spawn or pcall, but none of it is easily applicable to a situation where all of the parallelism is implicit. There are some very interesting ideas out there, like existing work on subcontinuations and on inter-thread communication with ports, which I may come back to in later posts; but for now, it looks like I, too, can be on the cutting edge of functional programming research.
      What we want is not for all argument threads to terminate before executing the function body, but for all argument slots to be filled, whether by the same threads that were started for that purpose or not. In order to do that, we'll introduce a counter of unfilled argument positions, decremented whenever an argument slot is filled; when the counter reaches zero, we can run the function body. A version of cps_map_eval that does that looks like this:

from threading import Thread, Event, Lock

def cps_map_eval(k,v,*x):
    """
    Evaluates the elements of an argument list,
    creating continuations that will assign values
    to the correct indices in the evaluated list.
    """
    arglen = len(x)
    if arglen == 0: return k([])

    counter = [arglen]
    finished = Event()
    flock = Lock()
    argv = [None]*arglen

    def assign_val(i,val):
        argv[i] = val
        flock.acquire()
        counter[0] -= 1
        remaining = counter[0]  # read under the lock so only one thread sees zero
        flock.release()
        if remaining == 0:
            finished.set()

    def reassign(i,val):
        finished.wait()
        new_argv = argv[:]
        new_argv[i] = val
        return k(new_argv)

    def arg_thread(i,ax):
        eval(ax,v,ArgK( lambda val: assign_val(i,val),
                        lambda val: reassign(i,val)))

    threads = [Thread(target=arg_thread,args=(i,ax))
                for i, ax in enumerate(x[:-1])]

    for t in threads: t.start()

    def arg_k(val):
        argv[-1] = val
        flock.acquire()
        counter[0] -= 1
        remaining = counter[0]
        flock.release()
        if remaining == 0:
            finished.set()
        else:
            finished.wait()
        for t in threads: t.join()
        return k(argv)

    return Tail(x[-1],v,
                ArgK(arg_k,lambda val: reassign(-1,val)))

      Again, we're using a 1-element array because Python won't let us rebind variables in enclosing closure scopes, but will let us alter the contents of collections. Initial continuations fill in values and decrement the counter; non-initial continuations, and the parent thread that will execute the function body, wait until all of the slots are filled (every initial continuation has been called). When all of the slots are filled, any waiting threads are unblocked and the function body can be executed as many times as continuations containing it have been activated. This almost works. Except for one rather important problem. This program will never terminate:

((vau () % (+ (return 1) 1)))

      The argument thread will send 1 to the outer continuation, never giving anything to (+ [] 1), and the parent thread will hang forever waiting for that slot to be filled. So, we can't allow the parent thread to block. To fix this, we don't privilege the parent thread. It becomes just another argument evaluator, whose sole speciality is the ability to call join() to eventually clean up all the other threads. Instead, every initial continuation checks to see whether it filled in the last argument slot, and if so, it evaluates the function body. This means that the thread that began a computation may not be the same thread that completes it; the continuation can get picked up by a different thread, and the parent may end up swapped out for one of its children. This version of cps_map_eval takes care of that:

def cps_map_eval(k,v,*x):
    """
    Evaluates the elements of an argument list,
    creating continuations that will assign values
    to the correct indices in the evaluated list.
    """
    arglen = len(x)
    if arglen == 0: return k([])

    counter = [arglen]
    finished = Event()
    flock = Lock()
    argv = [None]*arglen

    def assign_val(i,val):
        argv[i] = val
        flock.acquire()
        counter[0] -= 1
        remaining = counter[0]  # read under the lock so only one thread sees zero
        flock.release()
        if remaining == 0:
            finished.set()
            return k(argv)

    def reassign(i,val):
        finished.wait()
        new_argv = argv[:]
        new_argv[i] = val
        return k(new_argv)

    def arg_thread(i,ax):
        eval(ax,v,ArgK( lambda val: assign_val(i,val),
                        lambda val: reassign(i,val)))

    threads = [Thread(target=arg_thread,args=(i,ax))
                 for i, ax in enumerate(x[:-1])]

    for t in threads: t.start()

    def arg_k(val):
        r = assign_val(-1,val)
        for t in threads: t.join()
        return r

    return Tail(x[-1],v,
            ArgK(arg_k, lambda val: reassign(-1,val)))

      We're getting closer. All of the previous examples will now behave properly. But this one won't:

((vau () % (+ (return 1)
              ((vau () % (par (return 1)
                              (return 2)))))))

      The first argument slot is never filled. The first thread to return to the second argument slot fills in its value and terminates, and the second thread to return to that slot then waits for the first slot to be filled. It will wait forever, and that one blocked thread prevents the program from terminating.
      We could consider having some kind of thread garbage collector that periodically checks to see if all existing threads are blocked and if so terminates the program, or cleans up threads that can never be unblocked because the necessary continuations aren't held by anyone anymore. But what we really want is to simply not block in the first place. The choice of the name Event for the notification objects we've been using from Python's threading library so far is apt: we'd like something like an event loop that can store up computations and then execute them when an event (filling all the argument slots) occurs, or not if that event never occurs.
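      A rough sketch of that idea, outside the interpreter (the Pending class and its method names are hypothetical, not part of the language), might look like this: callbacks registered before the event fires are simply held, and if the event never fires they are garbage collected along with the object holding them.

from threading import Lock

# Hypothetical sketch of "hold the computation until the event fires".
class Pending:
    def __init__(self):
        self.fired = False
        self.callbacks = []
        self.lock = Lock()

    def add(self, fn):
        self.lock.acquire()
        if self.fired:
            self.lock.release()
            return fn()            # the event already happened: run immediately
        self.callbacks.append(fn)  # otherwise hold it; if fire() never comes,
        self.lock.release()        # it just gets garbage collected with us

    def fire(self):
        self.lock.acquire()
        self.fired = True
        pending = self.callbacks
        self.callbacks = []
        self.lock.release()
        for fn in pending: fn()
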
      We can achieve the same effect without actually building an event loop, though. Everything can still live nicely inside of cps_map_eval, and it looks like this:

def cps_map_eval(k,v,*x):
    """
    Evaluates the elements of an argument list,
    creating continuations that will assign values
    to the correct indices in the evaluated list.
    """
    arglen = len(x)
    if arglen == 0: return k([])

    counter = [arglen]
    flock = Lock()
    argv = [None]*arglen
    blocked = set()

    def assign_val(i,val):
        # initial return of argument i: fill the slot and decrement the counter
        argv[i] = val
        flock.acquire()
        counter[0] -= 1
        if counter[0] == 0:
            # this was the last empty slot: wake everything that was parked,
            # then run the function body in this thread
            flock.release()
            for t in blocked: t.start()
            r = k(argv)
            for t in blocked: t.join()
            blocked.clear()
            return r
        flock.release()

    def reassign(i,val):
        # non-initial return: re-run the body with a fresh copy of the arguments
        new_argv = argv[:]
        new_argv[i] = val
        return k(new_argv)

    def reactivate(i,val):
        # non-initial call: if every slot is already filled, run the body now;
        # otherwise park the re-execution in the blocked set instead of blocking
        flock.acquire()
        if counter[0] == 0:
            flock.release()
            return reassign(i,val)
        blocked.add(Thread(target=reassign,args=(i,val)))
        flock.release()

    def arg_thread(i,ax):
        eval(ax,v,ArgK( lambda val: assign_val(i,val),
                        lambda val: reactivate(i,val)))

    threads = [Thread(target=arg_thread,args=(i,ax))
                 for i, ax in enumerate(x[:-1])]

    for t in threads: t.start()

    def arg_k(val):
        r = assign_val(-1,val)
        for t in threads: t.join()
        return r

    return Tail(x[-1],v,
            ArgK(arg_k, lambda val: reactivate(-1,val)))

      If a continuation call cannot immediately continue, it's dumped into a blocked set and the thread terminates. If the argument slots are ever all filled, that set is checked and everything in it is re-activated. If they never are, the set just gets garbage collected once the necessary continuations are no longer held by anybody else. Note that thread join()s only occur after a parent thread has finished all possible useful work and can do nothing more than return a final value to the top-level eval loop; thus, they do not inhibit concurrency. We could even get along without them and let Python deal with the cleanup, but if we were using raw POSIX pthreads that wouldn't be an option, so we might as well clean up our own garbage. We did end up having to use one lock, to ensure atomic access to the counter, but that's pretty darn small; most of the time, it will never be contended. Note that we need to hold the lock around additions to the blocked set, just in case, but we don't need it when iterating over the set: we only get to that point if the counter has just been set to zero, and once the counter is zero, we know that no other thread will ever access the blocked set again.
      So there it is. Automatic concurrency that plays nice with continuations, without any need to modify the core vau evaluator.

As usual, working code can be found at https://github.com/gliese1337/schrodinger-lisp/.

Thursday, May 3, 2012

Automatic Concurrency

(Continuing the series of Schrodinger's Equation of Software and First Class Everything: Loops, Tail-Calls, and Continuations)

      So far, we've had a language with de jure concurrent evaluation of lambda arguments, but de facto left-to-right evaluation order. A modification to cps_map_eval can make our language genuinely concurrent.
      First, we'll drastically simplify the continuations of each argument evaluation, so they don't explicitly chain together anymore. Instead, we'll just evaluate the arguments in a loop. If the loop has already finished, and all of the arguments have been evaluated, the continuation is the same as before: reassign the current argument and re-call the continuation of the whole argument list. Otherwise, the continuation is implicitly the next iteration of the loop.

def cps_map_eval(k,v,*x):
    """
    Evaluates the elements of an argument list,
    creating continuations that will assign values
    to the correct indices in the evaluated list.
    """
    arglen = len(x)
    if arglen == 0: return k([])

    argv = [None]*arglen
    done = False
    
    def arg_thread(i,ax):
        def assign_val(val):
            if done:
                new_argv = argv[:]
                new_argv[i] = val
                return k(new_argv)
            else:
                argv[i] = val
        eval(ax,v,assign_val)

    for i, ax in enumerate(x):
        arg_thread(i,ax)
    
    done = True
    return k(argv)

      This is much simpler than our previous code, so why didn't we use it? Well, evaluating arguments in an explicit loop means that each call to eval is no longer a tail call. That's not a huge deal: the stack only grows by one recursive call to eval for each level of function-call nesting in argument positions in your source code. But we don't need to grow the stack at all, and doing so gains us nothing by itself, since this still results in implicit left-to-right sequential evaluation. However, notice that the extra function used to create a closure over the index of each argument evaluation is named arg_thread; with the code in this form, we can arrange to execute each argument evaluation in its own concurrent thread.

from threading import Thread

def cps_map_eval(k,v,*x):
    """
    Evaluates the elements of an argument list,
    creating continuations that will assign values
    to the correct indices in the evaluated list.
    """
    arglen = len(x)
    if arglen == 0: return k([])

    argv = [None]*arglen
    done = False
    
    def arg_thread(i,ax):
        def assign_val(val):
            if done:
                new_argv = argv[:]
                new_argv[i] = val
                return k(new_argv)
            else:
                argv[i] = val
        eval(ax,v,assign_val)

    threads = [Thread(target=arg_thread,args=(i,ax))
                for i, ax in enumerate(x)]

    for t in threads: t.start()
    for t in threads: t.join()
    
    done = True
    return k(argv)

      This isn't quite true parallelism because of Python's Global Interpreter Lock. But it is true concurrency, and with a better underlying implementation of threads it would result in automatic parallelism. The existing thread starts up one new thread for each argument to be evaluated, waits for all of them to terminate, and then returns to the main eval loop. Each new thread makes one call to eval to start up its own evaluation loop. However, this implementation creates one more thread than necessary: the main thread sits idle while all of the arguments are evaluating. This is especially heinous if you only have one argument; why start a new thread just to do one sequential evaluation? That can be fixed by changing just a couple of lines so that the last argument in the list is evaluated in the current thread:

    threads = [Thread(target=arg_thread,args=(i,ax))
                for i, ax in enumerate(x[:-1])]

    for t in threads: t.start()
    arg_thread(arglen-1,x[-1]) #make use of the current thread
    for t in threads: t.join()

      But now we've re-introduced a recursive call to eval (inside of arg_thread)! Let's go ahead and CPS that away:

def cps_map_eval(k,v,*x):
    """
    Evaluates the elements of an argument list,
    creating continuations that will assign values
    to the correct indices in the evaluated list.
    """
    arglen = len(x)
    if arglen == 0: return k([])

    argv = [None]*arglen
    done = [False] # 1-element list: lets the nested closures mutate the flag by altering its contents
    
    def arg_thread(i,ax):
        def assign_val(val):
            if done[0]:
                new_argv = argv[:]
                new_argv[i] = val
                return k(new_argv)
            else:
                argv[i] = val
        eval(ax,v,assign_val)

    threads = [Thread(target=arg_thread,args=(i,ax))
                for i, ax in enumerate(x[:-1])]

    for t in threads: t.start()
    
    def arg_k(val):
        if done[0]:
            new_argv = argv[:]
            new_argv[-1] = val
            return k(new_argv)
        else:
            argv[-1] = val
            for t in threads: t.join()
            done[0] = True
            return k(argv)
    
    return Tail(x[-1],v,arg_k)
    
      This has some ugly code duplication. We can get rid of that and eliminate all of the conditional branches at the same time. The new version of cps_map_eval looks like this:

def cps_map_eval(k,v,*x):
    """
    Evaluates the elements of an argument list,
    creating continuations that will assign values
    to the correct indices in the evaluated list.
    """
    arglen = len(x)
    if arglen == 0: return k([])

    argv = [None]*arglen
    
    def assign_val(i,val):
        argv[i] = val

    def reassign(i,val):
        new_argv = argv[:]
        new_argv[i] = val
        return k(new_argv)

    def arg_thread(i,ax):
        eval(ax,v,ArgK( lambda val: assign_val(i,val),
                        lambda val: reassign(i,val)))

    threads = [Thread(target=arg_thread,args=(i,ax))
                for i, ax in enumerate(x[:-1])]

    for t in threads: t.start()
    
    def arg_k(val):
        argv[-1] = val
        for t in threads: t.join()
        return k(argv)

    return Tail(x[-1],v,
                ArgK(arg_k,lambda val: reassign(-1,val)))
    
      ArgK is a callable wrapper class that takes one function to run the first time it's called and another function to run on every subsequent call. This eliminates a lot of nesting and allows reassign to be shared everywhere. The definition of ArgK looks like this:

### self-modifying continuations for argument evaluation

class ArgK():
    def __init__(self,first,rest):
        def k(val):
            self.k = rest
            return first(val)
        self.k = k

    def __call__(self,val):
        return self.k(val)
        
      And it contains no conditional branches.
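      Here's a quick, hypothetical sanity check of the call pattern, using list appends as stand-in continuations:

calls = []
k = ArgK(lambda val: calls.append(("first", val)),
         lambda val: calls.append(("rest",  val)))
k(1); k(2); k(3)
# calls is now [("first", 1), ("rest", 2), ("rest", 3)]:
# the first invocation swaps in the "rest" behavior for every later call.
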
      Now that we have a mechanism for concurrent evaluation of arguments, we can use it to build a concurrent analog to the sequential "begin" construction. The simplest way to do that is to pass the concurrent expressions as arguments to "list"; a simple vau expression that does so and returns the result of the first expression looks like this:

(define par (vau (a) % (car (eval % (cons list a)))))

      It's rather inefficient to create a list if you're going to throw most of it away, though. Just as we could make the built-in sequence function much simpler and more efficient than the sequential version of cps_map_eval, we can make a much better built-in concurrency construction. We'll have it parallel the semantics of "begin" by returning the value of the last listed expression.

def par(k,v,*x):
    """
    Evaluates arguments in parallel, returning the
    last listed value like sequence does.
    Ensures that all parallel threads terminate
    before continuing
    """
    if len(x) == 0: return k(None)
    final = [None]

    def call_k(val):
        return k(final[0])

    def par_thread(ax):
        eval(ax,v,ArgK(lambda val: None,call_k))
    
    threads = [Thread(target=par_thread,args=(ax,))
                for ax in x[:-1]]
    for t in threads: t.start()

    def par_k(val):
        for t in threads: t.join()
        final[0] = val
        return k(val)
    return Tail(x[-1],v,ArgK(par_k,call_k))
    
      The par function starts up a new thread with its own eval loop for every argument except the last, throws away the results of all those other threads, and saves the result of evaluating its last argument. And what if you don't want to throw away the results of all those other threads? Just call "list", and it will evaluate all of its arguments concurrently! You can test it with this sample program, which evaluates a bunch of expressions of varying complexity and prints the results in the order that they complete:

(define mul (lambda (a b) (* a b)))
(begin
    (print (mul (+ 1 2) 3))
    (print (* 4 5))
    (print (+ 10 12))
    (print 7))
(par
    (print (mul (+ 1 2) 3))
    (print (* 4 5))
    (print (+ 10 12))
    (print 7))
(print
    (list
        (print (mul (+ 1 2) 3))
        (print (* 4 5))
        (print (+ 10 12))
        (print 7)))
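
      (For reference, the arithmetic above works out to 9, 20, 22, and 7. The begin form should print those values in exactly that order; with par, and with the concurrent call to list, the same four values appear, but in whatever order the argument threads happen to finish.)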


As usual, working code can be found at https://github.com/gliese1337/schrodinger-lisp/.