The seventeen-failure number had been stuck for three days. The eXMC
test suite under EXMC_COMPILER=vulkan would chew through
six hundred or so passes and then time out on every test that
actually ran a NUTS sampler. We had Day 1's persistent buffer pool, we
had Day 5's stable inverse-softplus, we had Day 6's f64 elementwise
shaders, we had a four-input fused-chain compiler, we had right-folded
chain detection, we had a peephole that turned multiply(p, p)
into square(p). The fused chain shader handled up to
eight ops in a single dispatch, with a measured 1.6–4× speedup at the
C++ level. None of it was enough. Each dispatch in the leapfrog body
cost about 280 µs on the RTX 3060 Ti; multiply by thirty ops per
step, by a thousand steps, by N chains, and tests that needed to
finish inside sixty seconds were taking eight minutes.
The auto-fusion compiler we had been refining had grown to roughly six
hundred lines. It walked the Nx.Defn.Expr IR top-down. It
recognised left-folded chains where the chain bottom was the first arg
of every binary op. It recognised right-folded chains where the
chain bottom was the second arg, with sub-expression pre-evaluation. It
handled commutative swaps. It assigned buffer indices in a small
search across permutations of the input parameters. It had three
hundred lines of tests. It worked, mostly, on the patterns we had
tested.
Then I read Emily's source, and I deleted most of the compiler.
The shape of the problem
Nx.Vulkan dispatches every Nx.add, Nx.multiply,
Nx.exp as a separate Vulkan compute shader. That is the
default Nx contract. The cost of one dispatch on FreeBSD or Linux
under our current implementation breaks down roughly as follows:
about 20 µs for the BEAM→NIF→BEAM crossing, 50–100 µs for command
buffer recording, 50–100 µs for vkQueueSubmit plus the fence wait,
and about 20 µs for the descriptor set write. Roughly 280 µs per
dispatch in steady state.
EXLA-on-CUDA does the same workload in fifty milliseconds because
XLA fuses the entire leapfrog into one CUDA kernel.
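Those per-dispatch numbers make the timeout arithmetic easy to sketch. A back-of-envelope script, not project code; the chain count N stays a free multiplier:

```elixir
# Back-of-envelope cost of one sampler run per chain, using the
# figures above: 280 µs per dispatch, ~30 ops per leapfrog step,
# ~1000 steps. Multiply by N chains for a full test.
per_dispatch_us = 280
ops_per_step = 30
steps = 1000

per_chain_us = per_dispatch_us * ops_per_step * steps
IO.puts("#{per_chain_us / 1.0e6} s per chain")
# 8.4 s per chain before multiplying by N — a sixty-second test
# budget evaporates quickly.
```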
The classical answer is what we built: an IR-walking compiler that recognises chains and emits one fused dispatch per chain. The mathematics is straightforward. The implementation is not. Each new pattern shape — one-arg unary, two-arg binary, three-arg with sub-chain pre-evaluation, four-arg with assigned buffer indices, non-commutative reverse, peephole rewrites — is more code in the walker, more cases that can silently miss, and more tests to maintain. The code grew until it covered the canonical patterns we had benchmarked. It did not grow to cover the patterns eXMC actually used in production.
The diagnostic that exposed this was a microbenchmark that ran the
canonical leapfrog body
fn q, p -> Nx.add(q, Nx.multiply(p, p)) end through
both Nx.Vulkan.Compiler and Nx.Defn.Evaluator,
with a debug log printing whether the chain detector fired. The
detector matched on contrived patterns we had designed for. It missed
on the realistic shapes eXMC actually wrote. Adding the right-folded
extension fixed that one. There were others.
Reading Emily
Emily is a separate project — Elixir bindings and Nx backend for Apple's MLX, designed to run Bumblebee models on Apple Silicon. It solves a structurally similar problem: an Nx backend that needs to dispatch fused kernels for performance-critical paths.
Emily's compiler is one hundred forty-one lines. It validates options.
It delegates to Nx.Defn.Evaluator. That is all it does.
The moduledoc explains the choice in plain language: "The compiler
does not wrap mlx::core::compile. The bench harness under bench/native/
measured the fusion win at less than 1.2× on
transformer-shaped workloads — below the threshold that
justified the integration cost."
Emily decided not to do IR-level fusion. They had measured what the
compiler-driven fusion path would buy them, and the answer was: not
enough. Instead they ship something called Emily.Fast.
Emily.Fast is a module of named kernels. There is a
function called layer_norm. There is a function called
rms_norm. There is a function called
scaled_dot_product_attention. Each one is between fifteen
and forty lines. Each one looks roughly like this:
```elixir
def layer_norm(x, weight, bias, opts \\ []) do
  opts = Keyword.validate!(opts, eps: 1.0e-5)
  Expr.optional(:fast_layer_norm, [x, weight, bias, opts], &layer_norm_fallback/4)
end

defp layer_norm_fallback(x, weight, bias, opts) do
  # the same operation built from primitive Nx ops
  ...
end
```
The Nx.Defn.Expr.optional/3 call is the keystone. It
emits an IR node tagged with the name :fast_layer_norm
and a fallback function. At evaluation time, the Nx evaluator looks
at the active backend and asks: does this backend export a
function called fast_layer_norm with the right arity?
If yes, dispatch directly to it — one fused MLX kernel under
Emily.Backend. If no, run the fallback. The fallback is
just composed Nx operations; it works correctly on every backend that
implements the Nx primitives.
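Stripped of the tensor machinery, that dispatch decision is just a capability check on the active backend. A schematic mock of the shape of the logic — module names here are hypothetical and there is no Nx dependency; the real evaluator does considerably more bookkeeping:

```elixir
defmodule FastBackend do
  # Pretend fused kernel: this backend exports the named callback.
  def fast_layer_norm(x), do: {:fused, x}
end

defmodule SlowBackend do
  # No fast_layer_norm here; only the primitives.
end

defmodule OptionalDispatch do
  # Mirrors the evaluator's decision: if the backend exports the
  # named callback, dispatch to it; otherwise run the fallback
  # built from primitive ops.
  def run(backend, name, arity, args, fallback) do
    if function_exported?(backend, name, arity) do
      apply(backend, name, args)
    else
      apply(fallback, args)
    end
  end
end

fallback = fn x -> {:fallback, x} end
IO.inspect(OptionalDispatch.run(FastBackend, :fast_layer_norm, 1, [:t], fallback))
IO.inspect(OptionalDispatch.run(SlowBackend, :fast_layer_norm, 1, [:t], fallback))
```

The fallback closure is ordinary code over Nx primitives, which is why the same call site works on every backend.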
That is the entire mechanism. There is no compiler walking the IR looking for fusable patterns. The user explicitly says, "here, fuse this," at the call site. The backend obliges or falls back. It is so simple it almost reads like a disappointment, until you sit with it for a few minutes and realise all the things it does not need.
The things it does not need
The named-kernel approach does not need a pattern matcher in the
compiler, because the user has already named the pattern. It does not
need a shrinker to reduce false negatives, because the call site is
explicit and unambiguous. It does not need cross-backend coordination,
because the fallback runs anywhere — the same defn body, on
EXLA, on Nx.BinaryBackend, on EMLX, on Emily, on Nx.Vulkan. It does
not need new compiler code to add a new fused operation; it needs one
new function in Fast and one new callback in
Backend. The IR walker we had built was solving a problem
that the framework already had a primitive for.
There is an argument that automatic detection is intrinsically more
valuable: the user does not need to know about fused kernels, the
compiler finds them. This argument is correct in the steady state of a
mature compiler. It is wrong in the early state of a vendor backend
where the set of useful fusions is small, well-understood, and
unlikely to grow without careful design. The compiler-driven path
front-loads work to handle patterns that may not exist; the
named-kernel path matches work to demonstrated demand. Emily has
fifteen Fast functions. EXLA's CUDA fusions are similar
in count. The distribution is not infinite. There is no reason to
build infrastructure as if it were.
The refactor
The new module is called Nx.Vulkan.Fast. It is fifty-six
lines. It contains four kernels, named for what they do in the
context of NUTS sampling:
| Kernel | Defn fallback | Fused dispatch |
|---|---|---|
| `leapfrog_position(q, eps, p)` | `q + eps * p` | `fused_chain_4 [multiply: 1, add: 2]` |
| `leapfrog_momentum_half(p, half_eps, grad)` | `p + half_eps * grad` | same shader |
| `momentum_step(p, eps, grad)` | `p + eps * grad` | same shader |
| `inv_mass_apply(p, inv_mass)` | `p * inv_mass` | two-input multiply |
Each kernel is fifteen lines — an entry function that emits
Nx.Defn.Expr.optional/3, and a private
_fallback that builds the same operation from
Nx.add / Nx.multiply. The
Nx.Vulkan.Backend module gains four new callbacks
matching those names. Each callback is roughly twenty-five lines:
verify the operands' types and shapes, dispatch the four-input fused
chain, return the result tensor. If types or shapes do not match, the
callback transfers operands to Nx.BinaryBackend and runs
the fallback there — a safety net that ensures correctness even
when the fast path cannot apply.
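The callback shape — verify, dispatch fused, or fall back — reduces to a guard and two clauses. A schematic sketch with plain maps standing in for tensors; the names and map layout are illustrative, not the project's actual callback code:

```elixir
defmodule FastCallbackSketch do
  # Stand-in for the fused four-input chain dispatch (the happy path).
  defp dispatch_fused_chain(q, eps, p), do: {:fused, q.data + eps * p.data}

  # Stand-in for the cross-backend fallback: the same q + eps * p,
  # built from primitives.
  defp fallback(q, eps, p), do: {:fallback, q.data + eps * p.data}

  # Fast path only when both operands agree on dtype and shape;
  # the repeated pattern variables t and s enforce the match.
  def fast_leapfrog_position(%{type: t, shape: s} = q, eps, %{type: t, shape: s} = p),
    do: dispatch_fused_chain(q, eps, p)

  # Any mismatch drops to the safety net.
  def fast_leapfrog_position(q, eps, p), do: fallback(q, eps, p)
end

a = %{type: :f64, shape: {8}, data: 1.0}
b = %{type: :f64, shape: {8}, data: 2.0}
c = %{type: :f32, shape: {8}, data: 2.0}

IO.inspect(FastCallbackSketch.fast_leapfrog_position(a, 0.5, b))
IO.inspect(FastCallbackSketch.fast_leapfrog_position(a, 0.5, c))
```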
The eXMC NUTS leapfrog calls these named functions directly. Where the body used to be:
```elixir
defn leapfrog(q, p, eps, grad, half_eps) do
  p_half = Nx.add(p, Nx.multiply(half_eps, grad))
  q_new = Nx.add(q, Nx.multiply(eps, p_half))
  grad_new = log_prob_grad(q_new)
  p_new = Nx.add(p_half, Nx.multiply(half_eps, grad_new))
  ...
end
```
It becomes:
```elixir
defn leapfrog(q, p, eps, grad, half_eps) do
  p_half = Nx.Vulkan.Fast.leapfrog_momentum_half(p, half_eps, grad)
  q_new = Nx.Vulkan.Fast.leapfrog_position(q, eps, p_half)
  grad_new = log_prob_grad(q_new)
  p_new = Nx.Vulkan.Fast.leapfrog_momentum_half(p_half, half_eps, grad_new)
  ...
end
```
The diff is mechanical: replace each elementwise expression with a
named call, leave everything else alone. The function still works on
EXLA — Nx.Vulkan.Fast.leapfrog_position falls back
to Nx.add(q, Nx.multiply(eps, p)) on EXLA, which then
dispatches through XLA's own fusion. It still works on
Nx.BinaryBackend. It runs faster only on Nx.Vulkan,
because that is the backend with the fused-kernel callback.
What got deleted
The IR-walker stays. It has correct behaviour for the patterns it
catches; ripping it out would be churn for no benefit. It moves from
the primary path to the automatic fallback for
unannotated code. Code that explicitly calls
Nx.Vulkan.Fast does not pass through the walker. Code
that does not, still gets whatever the walker can detect.
What got deleted is the plan to extend the walker further. There were several open extensions: handling five-arg defns, mid-chain unary on the b-side, sub-chain materialization for arbitrary right-side expressions, a peephole library for common idioms. Each of those would be a hundred to three hundred lines. Each of those is no longer necessary, because the named-kernel path covers the cases they would have covered, with less code and explicit caller intent.
The microbenchmark
We refactored eXMC's NUTS leapfrog to call Nx.Vulkan.Fast.
We rebuilt. We ran a microbenchmark of one leapfrog body with
d = 8. The numbers came back in one second. They were not the
numbers we expected.
| Path | µs / leapfrog body | vs. naive |
|---|---|---|
| Naive (direct Nx ops, Evaluator) | 485 | 1.00× |
| `Nx.Vulkan.Fast.*` (named kernels via `Expr.optional`) | 4311 | 0.11× (9× slower) |
| IR-walking compiler (one fused dispatch on the recognised chain) | 388 | 1.25× |
The named-kernel architecture we had just spent a session building was nearly an order of magnitude slower than the naive direct-dispatch path it was supposed to replace. The IR walker we had been planning to delete was the fastest of the three.
Why
Every Nx.Defn.Expr.optional/3 node in the IR adds a layer
of indirection at evaluation time. The Nx evaluator has to: look up
whether the active backend exports the named callback, decode the
argument list, invoke function_exported?, dispatch
through the Erlang dynamic call machinery, and arrive at the backend
function. The backend function then performs its own argument
validation, transfers operands to the GPU if necessary, and finally
calls the NIF that does the work. On the RTX 3060 Ti we measured
this overhead at roughly 700 µs per Fast call.
A NUTS leapfrog body has six elementwise operations. Six times
700 µs is 4,200 µs; add the broadcast pair for the scalar eps
and you arrive at the 4311 µs the benchmark reported. The naive path
runs the same six ops as primitive Nx calls; a primitive backend
dispatch costs roughly 150 µs at worst, because it pays no
optional-callback indirection, so seven such dispatches bound the
naive body at about 1,000 µs, consistent with the measured 485. The
dispatch overhead, not the GPU work, explains the 9× gap.
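Plugging the measured overhead back in reproduces the benchmark's headline row. A reconstruction from the figures above, not a new measurement:

```elixir
# Reconstructing the Fast-path row from the measured overheads.
optional_call_us = 700   # per-call cost of the Expr.optional indirection
elementwise_ops = 6      # elementwise ops in one leapfrog body

fast_floor_us = elementwise_ops * optional_call_us
IO.puts("Fast-path floor: #{fast_floor_us} µs")
# 4200 µs before any GPU work; the eps broadcast pair accounts for
# the remainder of the 4311 µs the benchmark reported.
```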
For Bumblebee — layer norm over a four-thousand-element hidden
state, scaled-dot-product attention against an
eight-thousand-by-eight-thousand mask — the GPU work inside one
Fast call runs to several milliseconds, and the 700 µs of
dispatch overhead disappears into it. Emily's named-kernel path wins
because the kernels are fat. Our MCMC vectors are eight elements, so
the dispatch-overhead floor dominates everything.
The IR-walking compiler we had been disparaging dispatches the entire chain through one backend callback. There is no per-op optional indirection. There is one trace, one dispatch, one fused shader. For narrow shapes that match its pattern, it pays the dispatch overhead once. The benchmark shows the consequence: 388 microseconds end to end, one and a quarter times faster than naive, four-fifths the cost of even three primitive ops.
What we kept and what we deleted
We kept the IR-walking compiler. The right-folded chain detection. The four-input fused dispatch. The auto-detect for three- and four-argument defn bodies. All the work we had spent two sessions building, that we had been about to declare obsolete. The microbenchmark put it back as the primary path.
We kept the Nx.Vulkan.Fast module. The architectural
decoupling is sound — named kernels at the call site, fallbacks
that work cross-backend, optional dispatch via Nx's own primitive.
For workloads where each fused kernel does substantial GPU work, the
named-kernel path is correct. We just don't have those workloads. The
module ships as an explicit-opt-in API for users who do.
We deleted the eXMC leapfrog refactor. The diagonal-mass step went back to the original direct Nx-op path. The blog post that we had already written, declaring victory on the architectural shift, disappeared into the editor and re-emerged as this one.
The lesson, revised
The Emily pattern is not a universal upgrade over IR-walking compilers. It is a different point on the cost curve. Named kernels amortise their dispatch overhead against fat per-kernel work; the break-even depends on how fat. For elementwise chains on tensors of eight or sixty-four floats, the break-even is unfavourable, and the unmodified Nx primitive dispatch path beats the named-kernel path by a factor of nine.
The cost of not measuring before refactoring is the cost of refactoring. We had read Emily's source carefully. We had drawn the architectural analogy. We had not run the comparison benchmark. The benchmark was three lines — same fn, same args, two compiler choices, one ratio. Three lines would have saved a session.
The next time we are tempted by an architectural shift on the basis of a different project's design choices, we will run the microbenchmark first. The benchmark is always cheaper than the refactor. The benchmark is always cheaper than the blog post you have to rewrite.
What stays
The IR walker. The four-input fused shader. The auto-detect for three-
and four-argument defn bodies. Nx.Vulkan.Compiler remains
the primary path for typical defn-traced code. Nx.Vulkan.Fast
remains an explicit opt-in, kept honest by its fallbacks. The
benchmark stays in the repository as a test fixture, where the next
person to consider this question can run it and read the answer in
one second.
The Nx.Vulkan.Fast module is at
lib/nx_vulkan/fast.ex.
The matching backend callbacks are in
lib/nx_vulkan/backend.ex.
The Emily source that prompted the refactor is at
elixir-nx/emily/lib/emily/fast.ex.
The four-input fused shader the kernels dispatch under the hood is at
spirit/shaders/fused_elementwise_4in.comp,
written by mac-248 across three increments in late April 2026.