Thursday, May 3, 2012

Automatic Concurrency

(Continuing the series of Schrodinger's Equation of Software and First Class Everything: Loops, Tail-Calls, and Continuations)

      So far, we've had a language with de jure concurrent evaluation of lambda arguments, but de facto left-to-right evaluation order. A modification to cps_map_eval can make our language genuinely concurrent.
      First, we'll drastically simplify the continuations of each argument evaluation, so they don't explicitly chain together anymore. Instead, we'll just evaluate them all in a loop. If the loop has already finished, so that every argument has been evaluated at least once, the continuation is the same as before: reassign the current argument and re-call the continuation of the whole argument list. Otherwise, the continuation just stores its value, and what happens next is implicitly the next iteration of the loop.

def cps_map_eval(k,v,*x):
    """
    Evaluates the elements of an argument list,
    creating continuations that will assign values
    to the correct indices in the evaluated list.
    """
    arglen = len(x)
    if arglen == 0: return k([])

    argv = [None]*arglen
    done = False
    
    def arg_thread(i,ax):
        def assign_val(val):
            if done:
                new_argv = argv[:]
                new_argv[i] = val
                return k(new_argv)
            else:
                argv[i] = val
        eval(ax,v,assign_val)

    for i, ax in enumerate(x):
        arg_thread(i,ax)
    
    done = True
    return k(argv)

      This is much simpler than our previous code, so why didn't we use it? Well, evaluating arguments in an explicit loop means that each call to eval is no longer a tail call. That's not a huge deal, because the stack only grows as deep as the function calls you nest in argument positions in your source code. Still, we don't need to grow the stack at all, and doing so gains us nothing by itself, since this version still results in implicit left-to-right sequential evaluation. However, notice that the extra function used to create a closure over the index for each argument evaluation is named arg_thread. With the code in this form, we can arrange to execute each argument evaluation in its own concurrent thread.

from threading import Thread

def cps_map_eval(k,v,*x):
    """
    Evaluates the elements of an argument list,
    creating continuations that will assign values
    to the correct indices in the evaluated list.
    """
    arglen = len(x)
    if arglen == 0: return k([])

    argv = [None]*arglen
    done = False
    
    def arg_thread(i,ax):
        def assign_val(val):
            if done:
                new_argv = argv[:]
                new_argv[i] = val
                return k(new_argv)
            else:
                argv[i] = val
        eval(ax,v,assign_val)

    threads = [Thread(target=arg_thread,args=(i,ax))
                for i, ax in enumerate(x)]

    for t in threads: t.start()
    for t in threads: t.join()
    
    done = True
    return k(argv)

      This isn't quite true parallelism because of Python's Global Interpreter Lock. But it is true concurrency, and with a better underlying implementation of threads it would result in automatic parallelism. The existing thread starts up one new thread for each argument to be evaluated, waits for all of them to terminate, and then returns to the main eval loop. Each new thread makes one call to eval to start up its own evaluation loop. However, this implementation creates one more thread than necessary: the main thread sits idle while all of the arguments are evaluating. This is especially heinous if you only have one argument; why start a new thread just to do one sequential evaluation? That can be fixed by changing just a couple of lines so that the last argument in the list is evaluated in the current thread:

    threads = [Thread(target=arg_thread,args=(i,ax))
                for i, ax in enumerate(x[:-1])]

    for t in threads: t.start()
    arg_thread(arglen-1,x[-1]) #make use of the current thread
    for t in threads: t.join()

      But now we've re-introduced a recursive call to eval (inside of arg_thread)! Let's go ahead and CPS that away:

def cps_map_eval(k,v,*x):
    """
    Evaluates the elements of an argument list,
    creating continuations that will assign values
    to the correct indices in the evaluated list.
    """
    arglen = len(x)
    if arglen == 0: return k([])

    argv = [None]*arglen
    done = [False]
    
    def arg_thread(i,ax):
        def assign_val(val):
            if done[0]:
                new_argv = argv[:]
                new_argv[i] = val
                return k(new_argv)
            else:
                argv[i] = val
        eval(ax,v,assign_val)

    threads = [Thread(target=arg_thread,args=(i,ax))
                for i, ax in enumerate(x[:-1])]

    for t in threads: t.start()
    
    def arg_k(val):
        if done[0]:
            new_argv = argv[:]
            new_argv[-1] = val
            return k(new_argv)
        else:
            argv[-1] = val
            for t in threads: t.join()
            done[0] = True
            return k(argv)
    
    return Tail(x[-1],v,arg_k)
    
      This has some ugly code duplication. We can get rid of that and eliminate all of the conditional branches at the same time. The new version of cps_map_eval looks like this:

def cps_map_eval(k,v,*x):
    """
    Evaluates the elements of an argument list,
    creating continuations that will assign values
    to the correct indices in the evaluated list.
    """
    arglen = len(x)
    if arglen == 0: return k([])

    argv = [None]*arglen
    
    def assign_val(i,val):
        argv[i] = val

    def reassign(i,val):
        new_argv = argv[:]
        new_argv[i] = val
        return k(new_argv)

    def arg_thread(i,ax):
        eval(ax,v,ArgK( lambda val: assign_val(i,val),
                        lambda val: reassign(i,val)))

    threads = [Thread(target=arg_thread,args=(i,ax))
                for i, ax in enumerate(x[:-1])]

    for t in threads: t.start()
    
    def arg_k(val):
        argv[-1] = val
        for t in threads: t.join()
        return k(argv)

    return Tail(x[-1],v,
                ArgK(arg_k,lambda val: reassign(-1,val)))
    
      ArgK is a callable wrapper class that takes a function to run the first time it's called and a function to run every other time it's called. This eliminates a lot of nesting and allows reassign to be shared everywhere. The definition of ArgK looks like this:

### self-modifying continuations for argument evaluation

class ArgK():
    def __init__(self,first,rest):
        def k(val):
            self.k = rest
            return first(val)
        self.k = k

    def __call__(self,val):
        return self.k(val)
        
      And it contains no conditional branches.
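      To see the switch from first to rest in action, here is a minimal standalone sketch. The ArgK class is repeated so the snippet runs on its own; argv, results, and the two helper functions are hypothetical stand-ins for the closures inside cps_map_eval, with results playing the role of the continuation k by recording what it would have been called with.

```python
class ArgK:
    """Same shape as the ArgK class above, repeated so this runs alone."""
    def __init__(self, first, rest):
        def k(val):
            self.k = rest          # swap in the handler for all later calls
            return first(val)
        self.k = k

    def __call__(self, val):
        return self.k(val)

argv = [None, 'b']                 # slot 1 is already filled, slot 0 is pending
results = []                       # hypothetical stand-in for the continuation k

def assign_val(i, val):
    argv[i] = val                  # first call: fill the shared slot in place

def reassign(i, val):
    new_argv = argv[:]             # later calls: copy, patch, hand the copy to k
    new_argv[i] = val
    results.append(new_argv)
    return new_argv

k0 = ArgK(lambda val: assign_val(0, val),
          lambda val: reassign(0, val))

k0('a')    # first invocation writes into argv
k0('z')    # a re-entered continuation leaves argv untouched
```

      After both calls, argv holds the original values while results holds one patched copy, which is the behavior the conditional on done used to provide.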
      Now that we have a mechanism for concurrent evaluation of arguments, we can use that to build a concurrent analog to the sequential "begin" construction. The simplest way to do it is to pass concurrent expressions as arguments to "list"; a vau expression to do that and return the result of the first expression looks like this:

(define par (vau (a) % (car (eval % (cons list a)))))

      It's rather inefficient to create a list if you're going to throw most of it away, though. Just as we could make the built-in sequence function much simpler and more efficient than the sequential version of cps_map_eval, we can make a much better built-in concurrency construction. We'll have it parallel the semantics of "begin" by returning the value of the last listed expression.

def par(k,v,*x):
    """
    Evaluates arguments in parallel, returning the
    last listed value like sequence does.
    Ensures that all parallel threads terminate
    before continuing.
    """
    if len(x) == 0: return k(None)
    final = [None]

    def call_k(val):
        return k(final[0])

    def par_thread(ax):
        eval(ax,v,ArgK(lambda val: None,call_k))
    
    threads = [Thread(target=par_thread,args=(ax,))
                for ax in x[:-1]]
    for t in threads: t.start()

    def par_k(val):
        for t in threads: t.join()
        final[0] = val
        return k(val)
    return Tail(x[-1],v,ArgK(par_k,call_k))
    
      The par function starts up a new thread with its own eval loop for every argument except the last, throws away the results of all those other threads, and saves the result of evaluating its last argument. And what if you don't want to throw away the results of all those other threads? Just call "list", and it will evaluate all of its arguments concurrently! You can test it with this sample program, which evaluates a bunch of expressions of varying complexity and prints the results in the order that they are completed:

(define mul (lambda (a b) (* a b)))
(begin
    (print (mul (+ 1 2) 3))
    (print (* 4 5))
    (print (+ 10 12))
    (print 7))
(par
    (print (mul (+ 1 2) 3))
    (print (* 4 5))
    (print (+ 10 12))
    (print 7))
(print
    (list
        (print (mul (+ 1 2) 3))
        (print (* 4 5))
        (print (+ 10 12))
        (print 7)))
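
      If you don't have the interpreter handy, the shape of par can be approximated in plain Python. This is only a sketch under some loud assumptions: toy_eval is a hypothetical stand-in for the real eval that treats each expression as a thunk, toy_par calls it directly instead of returning a Tail (so there is no trampoline here), and printer and log are made-up helpers that mimic "print" by recording values in completion order.

```python
from threading import Thread

log = []   # hypothetical "stdout": values in the order they complete

def toy_eval(expr, env, k):
    """Hypothetical stand-in for the interpreter's eval: expr is a thunk."""
    return k(expr(env))

class ArgK:
    """Repeated from the post so this snippet is self-contained."""
    def __init__(self, first, rest):
        def k(val):
            self.k = rest
            return first(val)
        self.k = k

    def __call__(self, val):
        return self.k(val)

def toy_par(k, v, *x):
    if len(x) == 0: return k(None)
    final = [None]

    def call_k(val):
        return k(final[0])

    def par_thread(ax):
        # results of the non-final expressions are thrown away
        toy_eval(ax, v, ArgK(lambda val: None, call_k))

    threads = [Thread(target=par_thread, args=(ax,)) for ax in x[:-1]]
    for t in threads: t.start()

    def par_k(val):
        for t in threads: t.join()   # all side effects finish before k runs
        final[0] = val
        return k(val)

    # the real par returns Tail(...); here we just call toy_eval directly
    return toy_eval(x[-1], v, ArgK(par_k, call_k))

def printer(n):
    def thunk(env):
        log.append(n)   # list.append is atomic under CPython's GIL
        return n
    return thunk

result = toy_par(lambda val: val, None,
                 printer(9), printer(20), printer(22), printer(7))
```

      Here result is always the value of the last expression, and log always contains all four values by the time toy_par returns, but the order of the entries in log is up to the thread scheduler.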


As usual, working code can be found at https://github.com/gliese1337/schrodinger-lisp/.