Crimson Channel

A Commentary on Things


Project maintained by Owen Jow

Matrix Ordering

In the past, I remember being rather bemused by TensorFlow code which wrote linear transformations (matrix multiplications) as \(xW\). After all, in the single-vector case it’s always written as \(Wx\). How did we end up reversing the ordering of non-commutative matrix multiplication?

I’m embarrassed to admit I didn’t actually worry about this too much; it became no more than a nagging doubt in the back of my mind. Conceptually it was still the same two quantities being multiplied together, right? Somehow that was good enough for me. But of course this is no way to do things. And today I finally looked into it.

Suffice it to say, it all makes sense now. It’s actually blindingly simple.

When it comes to data-driven algorithms, operations are often batched. In the TensorFlow examples I came across, \(x\) was referring to a matrix containing a batch of input vectors, not just a single input vector. Hence its shape would be \(n \times d\), where \(n\) is the batch size and \(d\) is the dimensionality of each individual vector.

In the single-vector case, \(W\) would be \(h \times d\) (and would be multiplied by a \(d \times 1\) vector to produce an \(h \times 1\) embedding). But in order to mesh with the batch formulation we really need to do \(xW^T\). The transpose often doesn’t show up in code because the weight variable is usually defined with shape \(d \times h\) to begin with (so it already plays the role of \(W^T\)), which was definitely a contributor to my confusion.
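
To make the shapes concrete, here is a minimal NumPy sketch of the single-vector case. NumPy is standing in for TensorFlow here, and the sizes and the name `W_stored` are made up for illustration.

```python
import numpy as np

d, h = 4, 3  # illustrative sizes: input and output dimensionality
rng = np.random.default_rng(0)

W = rng.standard_normal((h, d))  # "textbook" layout: h x d
x = rng.standard_normal((d, 1))  # a single column vector: d x 1

y = W @ x                        # h x 1 embedding
print(y.shape)                   # (3, 1)

# Hypothetical storage layout: keep the weights as d x h (i.e. as W^T),
# so that right-multiplying needs no explicit transpose in code.
W_stored = W.T
row = x.T @ W_stored             # 1 x h, the same numbers as y laid out as a row
print(row.shape)                 # (1, 3)
```

With the weights stored as \(d \times h\), the code reads as \(xW\) even though the math is \(xW^T\).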

\(xW^T\) applies the weight matrix to a batch of input vectors, and keeps them in their batch form! Namely, it produces a batch of linearly transformed vectors as an \(n \times h\) matrix.
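
A quick shape check of the batched version, again as a NumPy sketch with made-up sizes rather than anything taken from the TensorFlow examples:

```python
import numpy as np

n, d, h = 5, 4, 3  # illustrative: batch size, input dim, output dim
rng = np.random.default_rng(0)

X = rng.standard_normal((n, d))  # batch of n input vectors, one per row
W = rng.standard_normal((h, d))  # weight matrix in h x d layout

Y = X @ W.T                      # batch of transformed vectors
print(Y.shape)                   # (5, 3)
```

The batch index stays in the first dimension on the way in and on the way out.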

Example

From here on out, I’m going to write \(x\) as \(X\) when it’s a matrix. (People should always do this.)
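If \(x_i\) is the \(i\)th input vector, \(x_{ij}\) is the \(j\)th component of the \(i\)th input vector, and \(w_k\) is the \(k\)th row of \(W\) (written as a column vector, so that \(w_k^Tx_i\) is a dot product), then we have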

\[\begin{align} XW^T &= \begin{bmatrix} x_{11} & x_{12} & \cdots & x_{1d} \\ x_{21} & x_{22} & \cdots & x_{2d} \\ \vdots & \vdots & \ddots & \vdots \\ x_{n1} & x_{n2} & \cdots & x_{nd} \end{bmatrix} \begin{bmatrix} w_{11} & w_{21} & \cdots & w_{h1} \\ w_{12} & w_{22} & \cdots & w_{h2} \\ \vdots & \vdots & \ddots & \vdots \\ w_{1d} & w_{2d} & \cdots & w_{hd} \end{bmatrix} \\ &= \begin{bmatrix} w_1^Tx_1 & w_2^Tx_1 & \cdots & w_h^Tx_1 \\ w_1^Tx_2 & w_2^Tx_2 & \cdots & w_h^Tx_2 \\ \vdots & \vdots & \ddots & \vdots \\ w_1^Tx_n & w_2^Tx_n & \cdots & w_h^Tx_n \end{bmatrix} \end{align}\]

as compared to the single-vector case (where \(x_j\) now denotes the \(j\)th component of the single vector \(x\)):

\[\begin{align} Wx &= \begin{bmatrix} w_{11} & w_{12} & \cdots & w_{1d} \\ w_{21} & w_{22} & \cdots & w_{2d} \\ \vdots & \vdots & \ddots & \vdots \\ w_{h1} & w_{h2} & \cdots & w_{hd} \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_d \end{bmatrix} \\ &= \begin{bmatrix} w_1^Tx \\ w_2^Tx \\ \vdots \\ w_h^Tx \end{bmatrix} \end{align}\]

The \(i\)th row of \(XW^T\) is simply \(Wx_i\) laid out as a row (i.e. \(W\) applied to the \(i\)th vector in \(X\))!
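
The row-by-row claim is easy to check numerically. This is a small NumPy sketch with made-up sizes; `Y_rowwise` is just an illustrative name.

```python
import numpy as np

n, d, h = 5, 4, 3  # illustrative: batch size, input dim, output dim
rng = np.random.default_rng(0)

X = rng.standard_normal((n, d))
W = rng.standard_normal((h, d))

Y = X @ W.T  # batched version: n x h

# Apply W to each input vector separately and stack the results as rows.
Y_rowwise = np.stack([W @ X[i] for i in range(n)])

rows_match = np.allclose(Y, Y_rowwise)
print(rows_match)  # True
```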

Note: we could also perform the linear mapping as \(WX^T\), which gives the \(h \times n\) transpose of \(XW^T\), with \(Wx_i\) as its \(i\)th column. But then the first dimension wouldn’t correspond to the minibatch index. So that’s no good.
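
For completeness, here is the same kind of NumPy sketch for the \(WX^T\) ordering (sizes and names are again made up):

```python
import numpy as np

n, d, h = 5, 4, 3  # illustrative: batch size, input dim, output dim
rng = np.random.default_rng(0)

X = rng.standard_normal((n, d))
W = rng.standard_normal((h, d))

Z = W @ X.T        # h x n: the batch index is now the second dimension
print(Z.shape)     # (3, 5)

# Same numbers as X W^T, just transposed: column i of Z is W applied to x_i.
is_transpose = np.allclose(Z, (X @ W.T).T)
first_column_matches = np.allclose(Z[:, 0], W @ X[0])
print(is_transpose, first_column_matches)  # True True
```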
