There’s nothing particularly complex about this, but a few things surprised me along the way so I figured I’d write up some notes.

To generate a random $n$-dimensional unit vector, first generate a vector whose entries are independent samples from a normal distribution, then normalize it to unit length. Why does this work? I’ll paraphrase a very helpful comment by mindoftea on Stack Overflow:

The probability of a point being at a given $[x, y]$ is $P(x) \times P(y)$. The Gaussian distribution has roughly the form $\exp(-x^2)$, so $\exp(-x^2) \times \exp(-y^2)$ is $$\exp(-(x^2+y^2)).$$ That is a function only of the point’s distance from the origin, so the resulting distribution is radially symmetric. This generalizes easily to higher dimensions.
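
In code, that’s all there is to it. Here’s a minimal sketch of the two steps using plain Nx calls outside of defn (the seed and the dimension {3} are arbitrary):

key = Nx.Random.key(42)

# Draw 3 independent standard-normal samples...
{v, _key} = Nx.Random.normal(key, 0.0, 1.0, shape: {3})

# ...and scale the vector to unit length.
u = Nx.divide(v, Nx.LinAlg.norm(v))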

Here’s a “visual” version of the algorithm for people familiar with computer graphics, inspired by a friend’s comments:

Observe that an $n$-dimensional Gaussian is radially symmetric around the origin (consider a bell curve or a Gaussian splat). The radial symmetry means that if you squeeze the probability density function (pdf) onto the unit sphere in $n$ dimensions you’ll end up with a uniform density. Just make sure to only move probability mass directly towards or away from the origin, which corresponds to only scaling points by scalars.

To generate the $n$-dimensional vector, note that Gaussians are separable, i.e. the $n$-dimensional Gaussian’s pdf is the product of $n$ independent $1$-dimensional Gaussian pdfs.
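
Spelled out for standard normals, separability is just the identity

$$\prod_{i=1}^{n} \frac{1}{\sqrt{2\pi}}\, e^{-x_i^2/2} = (2\pi)^{-n/2}\, e^{-\lVert x \rVert^2/2},$$

and the right-hand side depends on $x$ only through $\lVert x \rVert$, which is exactly the radial symmetry above.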

Here’s an Elixir Nx function that implements the algorithm.

defmodule Math do

  import Nx.Defn

  defn random_unit_vector(key, opts) do
    import Nx
    alias Nx.{LinAlg, Random}

    # A key is a rank-1 tensor; reject anything else up front.
    case Nx.shape(key) do
      {_} -> :ok
      _ -> raise "invalid shape"
    end

    dim = opts[:dim]
    vectorized_axes = key.vectorized_axes

    # One subkey per dimension, so we can draw dim independent samples.
    key_split = Random.split(key, parts: dim)

    axis = :random_unit_vector_key
    mean = 0
    stdev = 1

    # Draw one standard-normal sample per subkey, then collapse the
    # temporary axis into a plain {dim} vector while preserving the
    # caller's vectorized axes.
    v =
      key_split
      |> vectorize(axis)
      |> Random.normal_split(mean, stdev)
      |> revectorize(vectorized_axes, target_shape: {dim}, target_names: [axis])

    # Normalize to unit length.
    n = LinAlg.norm(v, axes: [axis], ord: 2)

    u =
      (v / n)
      |> Nx.rename([nil])

    # Return one of the split keys for the caller to reuse.
    {u, key_split[0]}
  end
end

You use it like so:

key = Nx.Random.key(37)

# One random 3-dimensional unit vector
{v, key} = Math.random_unit_vector(key, dim: 3)

# 8 random 2-dimensional unit vectors, vectorized along
# the :key axis
{vs, key} =
  Nx.Random.split(key, parts: 8)
  |> Nx.vectorize(:key)
  |> Math.random_unit_vector(dim: 2)
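
As a quick sanity check, the norms should all come out as (approximately) 1. For the vectorized result I use elementwise ops plus a sum, which operate per entry of the :key axis:

Nx.LinAlg.norm(v)            # a scalar, ≈ 1.0

vs |> Nx.pow(2) |> Nx.sum()  # vectorized over :key, each entry ≈ 1.0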

I ran into a few gotchas while writing this.

Nx.Random.split and number arguments

First I hit an error when changing def to defn. It turns out that’s because Nx.Random.split expects the value for parts to be a regular Elixir/BEAM number, not a tensor. But defn automatically converts all of its number arguments to tensors:

When numbers are given as arguments, they are always immediately converted to tensors on invocation. If you want to keep numbers as is or if you want to pass any other value to numerical definitions, they must be given as keyword lists.

Hence the dim = opts[:dim].
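
To make the contrast concrete, here’s a minimal sketch (module and function names are mine): dim survives as a plain integer because it travels inside the keyword list, whereas a positional number argument would have been converted to a tensor by defn and then rejected by Nx.Random.split.

defmodule SplitExample do
  import Nx.Defn

  # opts is a keyword list, so dim stays a plain integer inside defn.
  defn split_keys(key, opts) do
    dim = opts[:dim]
    Nx.Random.split(key, parts: dim)
  end
end

# Returns a {4, 2} tensor of subkeys.
SplitExample.split_keys(Nx.Random.key(0), dim: 4)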

Vectorization

Next I noticed that my function worked with single keys, but failed when I passed a vectorized tensor of multiple keys (to generate multiple random unit vectors). That’s because initially I called devectorize to remove the temporary vectorization axis I introduced for generating one random value per dimension:

axis = :random_unit_vector_key
v =
  key
  |> Random.split(parts: dim)
  |> vectorize(axis)
  |> Random.normal_split(0, 1)
  |> devectorize()

That works fine when the caller gives us a single key, but when the caller gives us a vectorized key, devectorize will not only remove my temporary vectorization axis but also any of the caller’s axes! revectorize lets us restore the original vectorized_axes while devectorizing only the temporary axis.
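
Here’s the difference in isolation, on a made-up tensor with a caller-style axis plus a temporary one (all names and sizes here are just for illustration):

t =
  Nx.iota({2, 3})
  |> Nx.vectorize(:callers_axis)  # vectorized_axes: [callers_axis: 2], shape {3}
  |> Nx.vectorize(:temp_axis)     # vectorized_axes: [callers_axis: 2, temp_axis: 3], shape {}

# devectorize removes *all* vectorized axes, the caller's included:
Nx.devectorize(t)                 # no vectorized axes, shape {2, 3}

# revectorize keeps the caller's axis and collapses only :temp_axis:
Nx.revectorize(t, [callers_axis: 2], target_shape: {3})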

For a while I didn’t realize that this was why I was getting a bunch of non-unit vectors. The problem was that LinAlg.norm(v) was happily computing a single norm over all of the vectors instead of giving me one norm per vector! The final v / n was then dividing every vector by the same scalar.
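
The same failure is easy to reproduce outside of defn. With a plain 2-D tensor standing in for the devectorized batch of sample vectors (the values are made up), norm without axes collapses everything into a single number, while axes: [1] gives one norm per row:

v = Nx.tensor([[3.0, 4.0], [6.0, 8.0]])

Nx.LinAlg.norm(v)             # Frobenius norm of the whole matrix, ≈ 11.18
Nx.LinAlg.norm(v, axes: [1])  # one norm per row: [5.0, 10.0]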

I tried using revectorize(vectorized_axes ++ [foo: :auto]), but that doesn’t help here: it just renames axis to foo.

The documentation has this identity for revectorize, which didn’t help me much because it uses names like vectorized_sizes without defining them:

assert revectorize(tensor, target_axes,
         target_shape: target_shape,
         target_names: target_names
       ) =
         tensor
         |> Nx.devectorize(keep_names: false)
         |> Nx.reshape(vectorized_sizes ++ target_shape, names: target_names)
         |> Nx.vectorize(vectorized_names)

It turns out vectorized_names and vectorized_sizes are the keys and values of target_axes, respectively, e.g.

target_axes = [foo: 1, bar: 2]
vectorized_sizes = {1, 2}
vectorized_names = [:foo, :bar]

Personally, these identities help me more:

tensor = # ...
vectorized_axes = tensor.vectorized_axes

assert tensor == revectorize(tensor, vectorized_axes)

assert tensor ==
         tensor
         |> vectorize(:foo)
         |> revectorize(vectorized_axes,
           target_shape: Nx.shape(tensor),
           target_names: Nx.names(tensor)
         )

assert tensor |> vectorize(:foo) == revectorize(tensor, vectorized_axes ++ [foo: :auto])