Andrzej unjello Lichnerowicz

Multiprocessing in Jupyter Notebooks

2024-10-13T20:01:33+02:00

TL;DR

To run heavy workloads in a Jupyter Notebook, you need to take your function and store it in a temporary Python source file so it’s importable as a module:

import contextlib
import inspect
import importlib
import multiprocessing
import pathlib
import sys
import tempfile
from concurrent.futures import ProcessPoolExecutor


@contextlib.contextmanager
def wrap_fn(fn):
    source = inspect.getsource(fn)
    with tempfile.NamedTemporaryFile(suffix=".py", mode="w", encoding="utf-8") as f:
        f.write(source)
        f.flush()
        path = pathlib.Path(f.name)
        directory = str(path.parent)
        sys.path.append(directory)
        # the temp file was created after the import system last scanned
        # sys.path, so force the path finders to look again
        importlib.invalidate_caches()
        module = importlib.import_module(path.stem)
        yield getattr(module, fn.__name__)
        sys.path.remove(directory)


def square(x):
    return x*x


context = multiprocessing.get_context("spawn")
with ProcessPoolExecutor(max_workers=12, mp_context=context) as executor, wrap_fn(square) as fn:
    for sq in executor.map(fn, range(0, 10)):
        print(f"{sq} ", end="")
print("")

Here’s what’s happening in this code:

  • The wrap_fn context manager takes your function fn.
  • It gets the source code of fn and writes it to a temporary Python file.
  • It then imports this temporary file as a module, making your function accessible via a module name.
  • Inside the context, fn is now the function from the temporary module and can be pickled and sent to child processes.
  • We use ProcessPoolExecutor to distribute the workload across multiple processes.

By doing this, we avoid the issue of the function not being picklable due to being defined in the interactive namespace.
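One caveat with the snippet above: if the body of the with block raises, the sys.path cleanup never runs. A slightly hardened sketch of the same idea wraps it in try/finally:

@contextlib.contextmanager
def wrap_fn(fn):
    source = inspect.getsource(fn)
    with tempfile.NamedTemporaryFile(suffix=".py", mode="w", encoding="utf-8") as f:
        f.write(source)
        f.flush()
        path = pathlib.Path(f.name)
        directory = str(path.parent)
        sys.path.append(directory)
        importlib.invalidate_caches()
        try:
            module = importlib.import_module(path.stem)
            yield getattr(module, fn.__name__)
        finally:
            # remove the temp directory from sys.path even if the block raised
            sys.path.remove(directory)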

But if you ask me on a deeper level…

Jupyter Notebooks offer a great way to work with many things, combining the ability to play around with data like in a REPL with the maintainability of a script. Once in a while, though, you face a problem that requires some serious raw power. For me, such a moment occurred when I was working on a data matching problem in a dataset containing almost \(900,000\) elements. Performing naive \(A \times B\) pair matching with similarity tests over a couple of parameters that take 1 second would have me processing the set for… about 25,684 years. A gentle reminder that:
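For the curious, the back-of-the-envelope math: \(900{,}000^2 \times 1\,\mathrm{s} \approx 8.1 \times 10^{11}\,\mathrm{s}\), which works out to roughly \(25{,}684\) years.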

Big Oh N squared is bad, and you should feel bad -- a Zoidberg meme, very mature... (original idea captioned in memegenerator)

Multiprocessing

So what do you do when you have so much crunching to do that you want to utilize all CPU cores? That’s when ProcessPoolExecutor steps in to save the day. For those unfamiliar, ProcessPoolExecutor is like that friend who helps you move heavy furniture; it lives in the concurrent.futures package of the standard Python library and, much like ThreadPoolExecutor, lets you distribute workload across a pool of workers. The difference is, these workers are separate processes, not threads.

With both threads and processes, the API looks innocently simple: you just pass the function that will be called by each worker with a piece of data:

import concurrent.futures

with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
    squares = executor.map(lambda x: x*x, range(0, 100000))

For threads, that’s pretty much it—you can pat yourself on the back and watch your CPU stay underutilized. But threads aren’t going to use all your cores due to the infamous Global Interpreter Lock (GIL). So what if we swap ThreadPoolExecutor for ProcessPoolExecutor to really get the party started? We get an exception immediately.
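Concretely, the naive swap looks like this (a minimal sketch; the list() forces the lazy map results to be consumed so the exception actually surfaces):

import concurrent.futures

with concurrent.futures.ProcessPoolExecutor(max_workers=20) as executor:
    squares = list(executor.map(lambda x: x*x, range(0, 100000)))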

_pickle.PicklingError: Can't pickle <function <lambda> at 0x10332e160>: 
  attribute lookup <lambda> on __main__ failed

What happened? Well, the first rule of multiprocessing is that it uses pickle to serialize a function and data between the parent process and child processes. And pickle’s documentation says:

Note that functions (built-in and user-defined) are pickled by fully qualified name, not by value. This means that only the function name is pickled, along with the name of the containing module and classes. Neither the function’s code, nor any of its function attributes are pickled.
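You can see this for yourself with nothing but the standard library; the pickled form of a module-level function is little more than a module name plus a function name (a small sketch):

import pickle


def square(x):
    return x*x


# The serialized bytes contain just a reference -- b'__main__' and b'square' --
# no bytecode and no function body.
print(pickle.dumps(square))

# A lambda, by contrast, has no importable name, so this raises the
# PicklingError shown above:
# pickle.dumps(lambda x: x*x)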

Since all lambdas share the same name – <lambda> – and that name can't be looked up in any module, we need to define a proper, named function:

import concurrent.futures

def square(x):
    return x*x

with concurrent.futures.ProcessPoolExecutor(max_workers=20) as executor:
    squares = executor.map(square, range(0, 100000))

Now it works, but the error above has one less obvious consequence: if we only send a fully qualified name over to another process, how does the function actually get called?

When the multiprocessing module starts a new child process, it doesn’t have access to your current interactive namespace. It’s like sending someone to fetch an item from your house without giving them the address—they need to know where to look.

Under the hood, when multiprocessing starts a new child process, it uses spawn on Windows and macOS. On Linux and other POSIX systems it still defaults to fork, although since Python 3.12 that throws a DeprecationWarning when the forking process has threads. Fork is a bit problematic from a safety point of view, as the child process inherits everything from the parent process. Spawn, on the other hand, starts a completely fresh Python interpreter process.
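If you want to check or pin the start method yourself, the multiprocessing module exposes it directly (a small sketch):

import multiprocessing

# "spawn" on Windows and macOS, "fork" on Linux (which, as noted above,
# comes with a DeprecationWarning since Python 3.12)
print(multiprocessing.get_start_method())

# Request a specific start method for one pool without touching the global default:
context = multiprocessing.get_context("spawn")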

This means that any variables or functions you’ve defined interactively in your Jupyter Notebook aren’t automatically available in the child processes. They’re starting from scratch, and unless your function is defined in a module they can import, they’re left in the dark.

Here’s a rough diagram, courtesy of the Python docs, of how the flow looks when you submit work to the process pool:

(Diagram from the Python docs: on the in-process side, the ProcessPoolExecutor with its futures, the Work Ids queue, and a Local Worker Thread; a Call Q and a Result Q connect it to the out-of-process Process Pool with workers Process #1..n.)

When you call map, it internally calls submit, which wraps your function and its arguments into a WorkItem structure and adds its id to a work_ids queue. A run-loop then pops those work items and, if they haven't been canceled, repackages them into CallItem structures and puts them on a call_queue. The multiprocessing.Queue handles pickle serialization before writing the data. It's up to pickle to resolve the fully qualified name of the function being serialized, and while doing that (and also when deserializing), pickle actually tries to import the module the function is part of:

# excerpt from CPython's pickle.py, Pickler.save_global
if name is None:
    name = getattr(obj, '__qualname__', None)
if name is None:
    name = obj.__name__

module_name = whichmodule(obj, name)
try:
    __import__(module_name, level=0)
    module = sys.modules[module_name]
    obj2, parent = _getattribute(module, name)
except (ImportError, KeyError, AttributeError):
    raise PicklingError(
        "Can't pickle %r: it's not found as %s.%s" %
        (obj, module_name, name)) from None

That brings us to a very important requirement:

IMPORTANT

A function that we pass to the executor must be importable as module.name

When you submit a function to ProcessPoolExecutor, the function and its arguments are pickled to be sent to the worker process. The pickling process relies on the function being defined at the module level with a name accessible via importing.
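You can see the root of the problem by asking a notebook-defined function which module it claims to live in (a small sketch):

def square(x):
    return x*x


# Prints '__main__' when defined in a notebook cell. Pickling the reference
# succeeds in the parent, but a spawned worker re-imports __main__ as a fresh,
# empty module without your notebook definitions, so the lookup fails there.
print(square.__module__)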

Enter Jupyter Notebook

The problem is that Jupyter Notebooks, aside from being JSON documents, execute code dynamically. Let's see how that works:

(Diagram: the Frontend sends the cell code over ØMQ as an execute request to the Jupyter Kernel and receives a reply; the kernel's do_execute hands the code to the IPython Kernel, which runs it in the IPython Interactive Shell.)

The frontend sends the cell code as text to the kernel:

do_execute_args = {
    "code": code,
    "silent": silent,
    "store_history": store_history,
    "user_expressions": user_expressions,
    "allow_stdin": allow_stdin,
}

if self._do_exec_accepted_params["cell_meta"]:
    do_execute_args["cell_meta"] = cell_meta
if self._do_exec_accepted_params["cell_id"]:
    do_execute_args["cell_id"] = cell_id

# Call do_execute with the appropriate arguments
reply_content = self.do_execute(**do_execute_args)

The kernel then decides how to interpret the text of the cell and what to do with it. The IPython Kernel, as the name suggests, runs the IPython interactive shell, so it sends the contents of the cell to the shell to run it:

if hasattr(shell, "run_cell_async") and hasattr(shell, "should_run_async"):
    run_cell = shell.run_cell_async
    should_run_async = shell.should_run_async
    with_cell_id = _accepts_parameters(run_cell, ["cell_id"])
else:
    should_run_async = lambda cell: False  # noqa: ARG005, E731
    # older IPython,
    # use blocking run_cell and wrap it in coroutine

    async def run_cell(*args, **kwargs):
        return shell.run_cell(*args, **kwargs)

Finally, IPython runs a separate session, transforming the cell’s text into an Abstract Syntax Tree (AST), compiling, and running it:

async def run_cell_async(
        self,
        raw_cell: str,
        store_history=False,
        silent=False,
        shell_futures=True,
        *,
        transformed_cell: Optional[str] = None,
        preprocessing_exc_tuple: Optional[AnyType] = None,
        cell_id=None,
    ) -> ExecutionResult:
    # ...
    code_ast = compiler.ast_parse(cell, filename=cell_name)
    # ...
    code_ast = self.transform_ast(code_ast)
    # ...
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
                           interactivity=interactivity,
                           compiler=compiler, result=result)

    self.last_execution_succeeded = not has_raised
    self.last_execution_result = result

As you can see, none of this looks like something that can be imported as a module: the code exists only in the interactive namespace of the IPython shell. Hence, to be able to use ProcessPoolExecutor, we must take our function and store it in a separate file – preferably somewhere in a temp folder – as a temporary module that can be imported. That makes it accessible to the child processes started by ProcessPoolExecutor: when pickle tries to serialize and deserialize our function, it can import it by name, and everything works smoothly.
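As a quick sanity check, reusing wrap_fn and square from the TL;DR, you can look at which module the wrapped function reports (the exact name is whatever the temporary file happened to be called):

with wrap_fn(square) as fn:
    # Something like 'tmpk3x9q2ab' -- a real, importable module on sys.path,
    # which is exactly the fully qualified name pickle needs.
    print(fn.__module__)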
