tldr: Parameter Server

Part of the tldr series, which summarizes papers I'm currently reading. Inspired by the morning paper.

Scaling Distributed Machine Learning with the Parameter Server [2014]

Li et al.

Introduction

The motivation for this paper is efficiently solving large-scale machine learning problems by distributing the work over many worker nodes. Fault tolerance is also a major goal: training tasks are often run in the cloud, which means dealing with unreliable machines.

Before going into parameter servers, here's a quick primer on machine learning in relation to systems.

The goal of ML can be thought of as finding a model, which is just a function approximation. For example, the function could be f(user profile), which maps a user's profile to the probability that the user clicks on an ad. Machine learning happens in two parts:

  • Training. Exposing the model to training data so the model can improve. This is an iterative process.
  • Inference. Applying the trained model to new data to make predictions. 
Parameter servers are meant to handle training.

Furthermore, features are properties of the input/training data, and parameters quantify the importance of each feature (for example, how much a feature matters when classifying a photo). Each parameter is a weight of only a few bytes, but there can be a huge number of them: parameter servers become useful when the parameters form a vector on the scale of 1 billion elements.
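To get a rough sense of that scale (the 4-bytes-per-weight figure is my assumption; only the ~1 billion count comes from the text), a billion 32-bit weights is already about 4 GB before any replication or optimizer state, which is one reason the vector gets partitioned across a group of servers:

```python
# Back-of-the-envelope: why a billion-parameter model wants to be sharded.
# (Illustrative numbers; only the ~1 billion scale comes from the text above.)
num_params = 10**9           # ~1 billion weights
bytes_per_param = 4          # assume one 32-bit float per weight
gigabytes = num_params * bytes_per_param / 10**9
print(f"raw parameter vector: ~{gigabytes:.0f} GB")   # ~4 GB before replication
```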

So, we need distributed training when:

  • The training dataset is too large to fit on a single machine.
  • We want to get parallel speedup.
  • We want fault tolerance.

Architecture 

In a parameter server, the nodes are grouped into one server group and many worker groups.
Nodes in the server group are each responsible for some of the parameters. Specifically, they: 
  • Each maintain a partition of the globally shared parameters
  • Communicate with each other to replicate and/or migrate parameters for reliability and scaling
  • Maintain a consistent view of the server metadata (assignment of parameter partitions, etc) 
Nodes in the worker groups do the actual training. Specifically, they: 
  • Locally store a portion of the training data 
  • Communicate only with server nodes to update and retrieve shared parameters
A scheduler node in each worker group is responsible for assigning tasks to workers, keeping track of their progress, and rescheduling unfinished tasks.

The overall training process (sketched in code after the list) is:

  • Initially:
    • Initialize parameters (for example, set them all to 1) 
    • Push parameters to workers
  • Each training iteration:
    • Workers compute updates to parameters
    • Workers push parameter updates to the corresponding parameter servers 
    • Parameter servers update parameters using a function defined by the user. They may have to aggregate changes from multiple workers.
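Here's a minimal sketch of that loop in Python, with synchronous workers and a plain averaged-gradient update standing in for the user-defined aggregation function; the class names, the least-squares model, and the learning rate are all invented for illustration.

```python
import numpy as np

# Toy sketch of the pull/compute/push/update cycle described above.

class Server:
    def __init__(self, num_params, lr=0.1):
        self.w = np.ones(num_params)   # e.g. initialize all parameters to 1
        self.lr = lr
        self.pending = []              # updates pushed by workers this iteration

    def push(self, update):
        self.pending.append(update)

    def apply_updates(self):
        # User-defined aggregation: here, average the workers' gradients.
        grad = np.mean(self.pending, axis=0)
        self.w -= self.lr * grad
        self.pending = []

    def pull(self):
        return self.w.copy()

class Worker:
    def __init__(self, data, labels):
        self.X, self.y = data, labels  # local shard of the training data

    def compute_update(self, w):
        # Gradient of squared error on the local shard (illustrative model).
        pred = self.X @ w
        return self.X.T @ (pred - self.y) / len(self.y)

# One parameter server, two workers, a few synchronous iterations.
rng = np.random.default_rng(0)
X, y = rng.normal(size=(100, 8)), rng.normal(size=100)
server = Server(num_params=8)
workers = [Worker(X[:50], y[:50]), Worker(X[50:], y[50:])]

for _ in range(10):
    w = server.pull()                       # workers pull current parameters
    for wk in workers:
        server.push(wk.compute_update(w))   # workers push their updates
    server.apply_updates()                  # server aggregates and applies
```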

Data model 

Nodes share data as key-value pairs. For example, when solving a loss minimization problem, the nodes will pass around (feature ID, weight) pairs. The authors also assume that the key-value pairs are ordered by key to make linear algebra operations possible. Although it's useful for users to think of the parameters as a long vector, they may actually be stored in a hash table or some other representation.
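As a rough illustration of why ordering by key helps (the parallel-array layout below is just one possible representation, not the paper's implementation): a key range maps to a contiguous slice, so range reads and vector math stay cheap.

```python
import numpy as np

# One possible representation: parallel arrays, keys kept sorted.
keys    = np.array([3, 17, 42, 99, 105])          # feature IDs
weights = np.array([0.1, -0.4, 0.0, 2.3, -1.2])   # corresponding parameters

def range_slice(lo, hi):
    """Return the weights whose keys fall in [lo, hi)."""
    i, j = np.searchsorted(keys, [lo, hi])
    return weights[i:j]

print(range_slice(10, 100))   # -> [-0.4  0.   2.3]
```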

Batching

To move data around, a node can push its local parameters to other nodes and pull parameters from them. The system supports range-based pushes and pulls that send whole key ranges at a time to improve throughput (a sketch in code follows the list):
  • w.push(R, dest) sends all entries in parameter w in a key range R to a destination (a specific node, or a node group) 
    • R can be a single key, or a whole key range 
  • w.pull(R, dest) reads all entries of w in a key range R from a destination
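A hedged sketch of what this interface could look like from a worker's point of view; KVWorker and StubServer are invented names, and only the push(R, dest) / pull(R, dest) shape mirrors the paper's interface.

```python
# Illustrative client-side view of range push/pull.

class StubServer:
    """Minimal in-memory stand-in for one parameter server node."""
    def __init__(self):
        self.store = {}

    def receive_push(self, batch):
        self.store.update(batch)

    def serve_pull(self, lo, hi):
        return {k: v for k, v in self.store.items() if lo <= k < hi}

class KVWorker:
    def __init__(self, servers):
        self.servers = servers            # handles to server nodes

    def push(self, key_range, entries, dest=None):
        """Send all local (key, value) entries whose keys fall in key_range."""
        lo, hi = key_range
        batch = {k: v for k, v in entries.items() if lo <= k < hi}
        for server in (dest or self.servers):
            server.receive_push(batch)    # one batched message per destination

    def pull(self, key_range, dest=None):
        """Read back all entries in key_range from the destination(s)."""
        lo, hi = key_range
        result = {}
        for server in (dest or self.servers):
            result.update(server.serve_pull(lo, hi))
        return result

w = KVWorker(servers=[StubServer()])
w.push((0, 100), {3: 0.1, 17: -0.4, 250: 2.0})   # key 250 is outside the range
print(w.pull((0, 100)))                          # {3: 0.1, 17: -0.4}
```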

Consistent Hashing

The system uses consistent hashing so that workers can compute parameter locations locally. Servers and keys are placed on the same ring: a first hash function H(server) maps each parameter server to a location on the ring, and a second hash function H'(key) maps each key to a location on the ring. A key is stored by the server that owns the segment of the ring it lands in.

Thus, there's no central lookup table to find parameter locations. 
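Here's a small sketch of that lookup, assuming the standard consistent-hashing rule that a key belongs to the next server clockwise on the ring; the MD5-based hash is a generic choice for illustration, not the paper's H and H'.

```python
import bisect
import hashlib

def ring_hash(s: str) -> int:
    """Map a string onto a 2**32 ring (generic choice, not the paper's hash)."""
    return int(hashlib.md5(s.encode()).hexdigest(), 16) % (2**32)

class Ring:
    def __init__(self, servers):
        # Place every server on the ring once (real systems often add
        # multiple virtual points per server for better balance).
        self.points = sorted((ring_hash(s), s) for s in servers)
        self.hashes = [h for h, _ in self.points]

    def owner(self, key: str) -> str:
        """Any worker can compute this locally -- no central lookup table."""
        h = ring_hash(key)
        i = bisect.bisect_right(self.hashes, h) % len(self.points)
        return self.points[i][1]

ring = Ring(["server-0", "server-1", "server-2"])
print(ring.owner("feature:12345"))   # same answer on every worker
```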

Relaxed Consistency

The system sometimes uses slightly stale parameters (up to 8-16 iterations late). This won't affect a model too much, although it may take more time to train. Moral of the story: inconsistency in ML is usually fine. 

Vector clocks are used as a synchronization mechanism. Each key range keeps a vector of timestamps for all the workers and servers. The timestamp represents the last time a machine saw an update for a key. Ranges again come in useful here because we don't need as much memory to store vector clocks. (Imagine storing a vector for each key!) 
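A rough sketch of one clock per key range (class and method names are made up): with n nodes and m ranges this costs O(n·m) timestamps, instead of one full vector per key.

```python
# Sketch: one vector clock per key range instead of per key.

class RangeVectorClock:
    def __init__(self, key_ranges, node_ids):
        # clock[range][node] = last timestamp at which `node` saw an
        # update for any key in `range`.
        self.clock = {r: {n: 0 for n in node_ids} for r in key_ranges}

    def record_update(self, key_range, node_id, t):
        self.clock[key_range][node_id] = max(self.clock[key_range][node_id], t)

    def node_time(self, key_range, node_id):
        return self.clock[key_range][node_id]

vc = RangeVectorClock(key_ranges=[(0, 1000), (1000, 2000)],
                      node_ids=["worker-0", "worker-1", "server-0"])
vc.record_update((0, 1000), "worker-0", t=7)
print(vc.node_time((0, 1000), "worker-1"))   # 0 -- hasn't seen iteration 7 yet
```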

Fault Tolerance

If a worker crashes, we lose the training data and some state stored locally on the worker. The authors recommend either restarting the worker or just dropping it. Dropping the task would have a negligible impact on the model because we only lose a very small portion of the training data. 

If a parameter server crashes, we lose all the parameters stored on that server. This is bad, so the authors make sure that parameters are replicated on multiple servers. (When a worker sends an update to a parameter server, the parameter server pushes the updated parameters to other servers before replying to the worker. This is similar to how writes work in GFS.)
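A minimal sketch of that replicate-before-ack flow (PrimaryServer and Replica are invented names; a real implementation would aggregate and pipeline these messages):

```python
# Sketch: the ack goes back to the worker only after the replicas
# have received the updated parameters.

class Replica:
    def __init__(self):
        self.params = {}

    def copy_from(self, updated):
        self.params.update(updated)

class PrimaryServer:
    def __init__(self, replicas):
        self.params = {}
        self.replicas = replicas

    def handle_push(self, update):
        self.params.update(update)       # 1. apply the worker's update locally
        for r in self.replicas:
            r.copy_from(update)          # 2. replicate the updated entries
        return "ack"                     # 3. acknowledge only after replication

primary = PrimaryServer(replicas=[Replica(), Replica()])
print(primary.handle_push({42: 0.37}))   # "ack" after both replicas are updated
```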

Conclusion

Parameter servers are a widely used, highly influential design. TensorFlow and other frameworks adopted parameter-server-style designs to scale distributed training.

