What Emily Could Not Teach Us

May 2026 · Nx.Vulkan · eXMC Architecture

The seventeen-failure number had been stuck for three days. The eXMC test suite under EXMC_COMPILER=vulkan would chew through six hundred or so passes and then time out on every test that actually ran a NUTS sampler. We had Day 1's persistent buffer pool, Day 5's stable inverse-softplus, Day 6's f64 elementwise shaders, a four-input fused-chain compiler, right-folded chain detection, and a peephole that turned multiply(p, p) into square(p). The fused-chain shader handled up to eight ops in a single dispatch, with a measured one-point-six-to-four-times speedup at the C++ level. None of it was enough. The leapfrog body took about two hundred eighty microseconds per dispatch on the RTX 3060 Ti. Multiply that by thirty ops per step, by a thousand steps, by N chains, and tests that needed to finish inside sixty seconds were taking eight minutes.

The auto-fusion compiler we had been refining had grown to roughly six hundred lines. It walked the Nx.Defn.Expr IR top-down. It recognised left-folded chains where the chain bottom was the first arg of every binary op. It recognised right-folded chains where the chain bottom was the second arg, with sub-expression pre-evaluation. It handled commutative swaps. It assigned buffer indices in a small search across permutations of the input parameters. It had three hundred lines of tests. It worked, mostly, on the patterns we had tested.
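
The peephole rewrite is the simplest of these passes to picture. A toy sketch of the multiply(p, p) → square(p) rule over a tuple-based stand-in for the IR — the real Nx.Defn.Expr nodes are structs, so these bare tuples are illustrative only:

```elixir
defmodule Peephole do
  # Toy IR: leaves are atoms, unary ops are {op, arg},
  # binary ops are {op, left, right}. Illustrative only —
  # the real Nx.Defn.Expr nodes are structs, not bare tuples.

  # multiply with two identical operands becomes square
  def rewrite({:multiply, a, a}), do: {:square, rewrite(a)}
  # recurse into the remaining binary and unary nodes
  def rewrite({op, a, b}), do: {op, rewrite(a), rewrite(b)}
  def rewrite({op, a}), do: {op, rewrite(a)}
  # leaves (parameters, constants) pass through unchanged
  def rewrite(leaf), do: leaf
end

# add(q, multiply(p, p)) rewrites to add(q, square(p))
IO.inspect(Peephole.rewrite({:add, :q, {:multiply, :p, :p}}))
```

The real pass runs over the traced expression before buffer-index assignment, but the matching idea is the same: structural equality of the two operands is the entire trigger.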

Then I read Emily's source, and I deleted most of the compiler.

The shape of the problem

Nx.Vulkan dispatches every Nx.add, Nx.multiply, Nx.exp as a separate Vulkan compute shader. That is the default Nx contract. The cost of one dispatch on FreeBSD or Linux under our current implementation breaks down roughly as follows: twenty microseconds for the BEAM-NIF-BEAM crossing, fifty to a hundred for command buffer recording, fifty to a hundred for vkQueueSubmit plus fence wait, twenty for descriptor set write. End to end it measures about two hundred eighty microseconds per dispatch in steady state. EXLA-on-CUDA does the same workload in fifty milliseconds because XLA fuses the entire leapfrog into one CUDA kernel.
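
Spelled out, this is the arithmetic that turns one slow dispatch into a timed-out suite. The per-stage figures are the rough numbers above with ranges taken at their midpoints, and the chain count is a stand-in:

```elixir
# Per-dispatch cost model in microseconds, using the rough stage
# figures above (ranges collapsed to midpoints). The measured
# steady-state total sits a little above the sum of the stages.
stages = [
  nif_crossing: 20,      # BEAM -> NIF -> BEAM
  recording: 75,         # midpoint of 50–100
  submit_and_fence: 75,  # midpoint of 50–100
  descriptor_write: 20
]
IO.inspect(Enum.sum(Keyword.values(stages)))  # 190 — the rest is slack

per_dispatch_us = 280  # measured steady state
ops_per_step    = 30
steps           = 1000
chains          = 4    # hypothetical chain count

total_s = per_dispatch_us * ops_per_step * steps * chains / 1_000_000
IO.puts("#{total_s} s of dispatch overhead alone")  # 33.6 s for four chains
```

Four chains of pure dispatch overhead already sit well past the sixty-second budget before any warmup or tree-doubling work is counted.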

The classical answer is what we built: an IR-walking compiler that recognises chains and emits one fused dispatch per chain. The mathematics is straightforward. The implementation is not. Each new pattern shape — one-arg unary, two-arg binary, three-arg with sub-chain pre-evaluation, four-arg with assigned buffer indices, non-commutative reverse, peephole rewrites — is more code in the walker, more cases that can silently miss, and more tests to maintain. The code grew until it covered the canonical patterns we had benchmarked. It did not grow to cover the patterns eXMC actually used in production.

The diagnostic that exposed this was a microbenchmark that ran the canonical leapfrog body fn q, p -> Nx.add(q, Nx.multiply(p, p)) end through both Nx.Vulkan.Compiler and Nx.Defn.Evaluator, with a debug log printing whether the chain detector fired. The detector matched on contrived patterns we had designed for. It missed on the realistic shapes eXMC actually wrote. Adding the right-folded extension fixed that one. There were others.

Reading Emily

Emily is a separate project — Elixir bindings and Nx backend for Apple's MLX, designed to run Bumblebee models on Apple Silicon. It solves a structurally similar problem: an Nx backend that needs to dispatch fused kernels for performance-critical paths.

Emily's compiler is one hundred forty-one lines. It validates options. It delegates to Nx.Defn.Evaluator. That is all it does. The moduledoc explains the choice in plain language: "The compiler does not wrap mlx::core::compile. The bench harness under bench/native/ measured the fusion win at less than one-point-two-x on transformer-shaped workloads — below the threshold that justified the integration cost."

Emily decided not to do IR-level fusion. Its authors had measured what the compiler-driven fusion path would buy them, and the answer was: not enough. Instead they ship something called Emily.Fast.

Emily.Fast is a module of named kernels. There is a function called layer_norm. There is a function called rms_norm. There is a function called scaled_dot_product_attention. Each one is between fifteen and forty lines. Each one looks roughly like this:

def layer_norm(x, weight, bias, opts \\ []) do
  opts = Keyword.validate!(opts, eps: 1.0e-5)
  Expr.optional(:fast_layer_norm, [x, weight, bias, opts], &layer_norm_fallback/4)
end

defp layer_norm_fallback(x, weight, bias, opts) do
  # the same operation built from primitive Nx ops
  ...
end

The Nx.Defn.Expr.optional/3 call is the keystone. It emits an IR node tagged with the name :fast_layer_norm and a fallback function. At evaluation time, the Nx evaluator looks at the active backend and asks: does this backend export a function called fast_layer_norm with the right arity? If yes, dispatch directly to it — one fused MLX kernel under Emily.Backend. If no, run the fallback. The fallback is just composed Nx operations; it works correctly on every backend that implements the Nx primitives.
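
The rule is small enough to simulate without Nx at all. A toy version of the capability check, assuming only that the evaluator's test is function_exported?/3 — the module names and the fast_axpy kernel are invented for the example:

```elixir
defmodule TinyOptional do
  # Shape of the optional-node dispatch: if the active backend
  # exports the named callback with the right arity, call it;
  # otherwise run the portable fallback composed from primitives.
  def dispatch(backend, name, args, fallback) do
    if function_exported?(backend, name, length(args)) do
      apply(backend, name, args)
    else
      apply(fallback, args)
    end
  end
end

defmodule FusedBackend do
  # pretend fused kernel: the whole expression in one "dispatch"
  def fast_axpy(a, x, y), do: a * x + y
end

defmodule PlainBackend do
  # exports no fast_axpy — TinyOptional will run the fallback
end

fallback = fn a, x, y -> a * x + y end  # composed "primitive" ops
IO.inspect(TinyOptional.dispatch(FusedBackend, :fast_axpy, [2, 3, 4], fallback))  # 10, fused path
IO.inspect(TinyOptional.dispatch(PlainBackend, :fast_axpy, [2, 3, 4], fallback))  # 10, fallback path
```

Both calls return the same value; only the route differs. That is the whole cross-backend guarantee.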

That is the entire mechanism. There is no compiler walking the IR looking for fusable patterns. The user explicitly says, "here, fuse this," at the call site. The backend obliges or falls back. It is so simple it almost reads like a disappointment, until you sit with it for a few minutes and realise all the things it does not need.

The things it does not need

The named-kernel approach does not need a pattern matcher in the compiler, because the user has already named the pattern. It does not need a shrinker to reduce false negatives, because the call site is explicit and unambiguous. It does not need cross-backend coordination, because the fallback runs anywhere — the same defn body, on EXLA, on Nx.BinaryBackend, on EMLX, on Emily, on Nx.Vulkan. It does not need new compiler code to add a new fused operation; it needs one new function in Fast and one new callback in Backend. The IR walker we had built was solving a problem that the framework already had a primitive for.

There is an argument that automatic detection is intrinsically more valuable: the user does not need to know about fused kernels, the compiler finds them. This argument is correct in the steady state of a mature compiler. It is wrong in the early state of a vendor backend where the set of useful fusions is small, well-understood, and unlikely to grow without careful design. The compiler-driven path front-loads work to handle patterns that may not exist; the named-kernel path matches work to demonstrated demand. Emily has fifteen Fast functions. EXLA's CUDA fusions are similar in count. The distribution is not infinite. There is no reason to build infrastructure as if it were.

The refactor

The new module is called Nx.Vulkan.Fast. It is fifty-six lines. It contains four kernels, named for what they do in the context of NUTS sampling:

Kernel                                    | Defn fallback       | Fused dispatch
leapfrog_position(q, eps, p)              | q + eps * p         | fused_chain_4 [multiply: 1, add: 2]
leapfrog_momentum_half(p, half_eps, grad) | p + half_eps * grad | same shader
momentum_step(p, eps, grad)               | p + eps * grad      | same shader
inv_mass_apply(p, inv_mass)               | p * inv_mass        | two-input multiply

Each kernel is fifteen lines — an entry function that calls Nx.Defn.Expr.optional/3, and a private _fallback that builds the same operation from Nx.add / Nx.multiply. The Nx.Vulkan.Backend module gains four new callbacks matching those names. Each callback is roughly twenty-five lines: verify the operands' types and shapes, dispatch the four-input fused chain, return the result tensor. If types or shapes do not match, the callback transfers operands to Nx.BinaryBackend and runs the fallback there — a safety net that ensures correctness even when the fast path cannot apply.
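
The guard-then-fallback shape of those callbacks can be sketched with toy tensors — {shape, data} tuples standing in for real Nx tensors, and list arithmetic standing in for the fused shader. Everything here is illustrative, not the real backend code:

```elixir
defmodule GuardedKernel do
  # Toy tensors: {shape, data}. The first clause is the "fast path" —
  # it fires only when both operands share a shape, the way the real
  # callback verifies types and shapes before dispatching the fused chain.
  def leapfrog_position({shape, q}, eps, {shape, p}) do
    {shape, Enum.zip_with(q, p, fn qi, pi -> qi + eps * pi end)}
  end

  # Safety net: momentum given as a bare scalar gets broadcast and
  # re-dispatched, the way the real callback falls back to the
  # composed-ops path when the fast path cannot apply.
  def leapfrog_position({shape, q} = t, eps, p) when is_number(p) do
    leapfrog_position(t, eps, {shape, List.duplicate(p, length(q))})
  end
end

q = {{3}, [1.0, 2.0, 3.0]}
p = {{3}, [4.0, 4.0, 4.0]}
IO.inspect(GuardedKernel.leapfrog_position(q, 0.5, p))  # {{3}, [3.0, 4.0, 5.0]}
```

Multiple function heads make the guard explicit: the caller never chooses a path, the operand shapes do.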

The eXMC NUTS leapfrog calls these named functions directly. Where the body used to be:

defn leapfrog(q, p, eps, grad, half_eps) do
  p_half = Nx.add(p, Nx.multiply(half_eps, grad))
  q_new  = Nx.add(q, Nx.multiply(eps, p_half))
  grad_new = log_prob_grad(q_new)
  p_new  = Nx.add(p_half, Nx.multiply(half_eps, grad_new))
  ...
end

It becomes:

defn leapfrog(q, p, eps, grad, half_eps) do
  p_half = Nx.Vulkan.Fast.leapfrog_momentum_half(p, half_eps, grad)
  q_new  = Nx.Vulkan.Fast.leapfrog_position(q, eps, p_half)
  grad_new = log_prob_grad(q_new)
  p_new  = Nx.Vulkan.Fast.leapfrog_momentum_half(p_half, half_eps, grad_new)
  ...
end

The diff is mechanical: replace each elementwise expression with a named call, leave everything else alone. The function still works on EXLA — Nx.Vulkan.Fast.leapfrog_position falls back to Nx.add(q, Nx.multiply(eps, p)) on EXLA, which then dispatches through XLA's own fusion. It still works on Nx.BinaryBackend. It runs faster only on Nx.Vulkan, because that is the backend with the fused-kernel callback.

What got deleted

The IR-walker stays. It has correct behaviour for the patterns it catches; ripping it out would be churn for no benefit. It moves from the primary path to the automatic fallback for unannotated code. Code that explicitly calls Nx.Vulkan.Fast does not pass through the walker. Code that does not, still gets whatever the walker can detect.

What got deleted is the plan to extend the walker further. There were several open extensions: handling five-arg defns, mid-chain unary on the b-side, sub-chain materialization for arbitrary right-side expressions, a peephole library for common idioms. Each of those would be a hundred to three hundred lines. Each of those is no longer necessary, because the named-kernel path covers the cases they would have covered, with less code and explicit caller intent.

The microbenchmark

We refactored eXMC's NUTS leapfrog to call Nx.Vulkan.Fast. We rebuilt. We ran a microbenchmark of one leapfrog body with d = 8. The numbers came back in one second. They were not the numbers we expected.

Path                                                             | µs / leapfrog body | vs. naive
Naive (direct Nx ops, Evaluator)                                 | 485                | 1.00×
Nx.Vulkan.Fast.* (named kernels via Expr.optional)               | 4311               | 0.11× (9× slower)
IR-walking compiler (one fused dispatch on the recognised chain) | 388                | 1.25×

The named-kernel architecture we had just spent a session building was nearly an order of magnitude slower than the naive direct-dispatch path it was supposed to replace. The IR walker we had been planning to delete was the fastest of the three.

Why

Every Nx.Defn.Expr.optional/3 node in the IR adds a layer of indirection at evaluation time. The Nx evaluator has to check, via function_exported?, whether the active backend exports the named callback, decode the argument list, dispatch through the Erlang dynamic call machinery, and arrive at the backend function. The backend function then performs its own argument validation, transfers operands to the GPU if necessary, and finally calls the NIF that does the work. On the RTX 3060 Ti we measured this overhead at roughly seven hundred microseconds per Fast call.

A NUTS leapfrog body has six elementwise operations. Six times seven hundred microseconds is forty-two hundred microseconds. Add the broadcast pair for the scalar eps and you arrive at the 4311 figure the benchmark reported. The naive path runs the same six ops as primitive Nx calls; each primitive backend dispatch costs roughly a hundred fifty microseconds because it does not pay the optional-callback indirection. That puts the naive estimate around one thousand microseconds — the same order as the measured 485, close enough to confirm the model; the Fast estimate matches its measurement almost exactly.
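
As arithmetic, using the rough per-call figures measured above:

```elixir
# Dispatch-cost models for the two paths, in microseconds.
optional_overhead_us = 700  # measured per Fast call on the 3060 Ti
fast_calls           = 6    # elementwise ops in one leapfrog body
fast_estimate = fast_calls * optional_overhead_us
IO.inspect(fast_estimate)   # 4200 — plus the broadcast pair ≈ the 4311 measured

primitive_dispatch_us = 150 # per plain backend dispatch, no optional indirection
naive_estimate = 7 * primitive_dispatch_us
IO.inspect(naive_estimate)  # 1050 — same order as the measured naive path
```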

For Bumblebee — for layer norm against a four-thousand-element hidden state, for scaled-dot-product attention against an eight-thousand-by-eight-thousand mask — the GPU work inside one Fast call is several milliseconds. The seven hundred microseconds of dispatch overhead disappears against the substrate work. Emily's named-kernel path wins because the kernels are fat. Our MCMC vectors are eight elements, and the dispatch-overhead floor dominates.

The IR-walking compiler we had been disparaging dispatches the entire chain through one backend callback. There is no per-op optional indirection. There is one trace, one dispatch, one fused shader. For narrow shapes that match its pattern, it pays the dispatch overhead once. The benchmark shows the consequence: 388 microseconds end to end, one and a quarter times faster than naive, four-fifths the cost of even three primitive ops.

What we kept and what we deleted

We kept the IR-walking compiler. The right-folded chain detection. The four-input fused dispatch. The auto-detect for three- and four-argument defn bodies. All the work we had spent two sessions building, that we had been about to declare obsolete. The microbenchmark put it back as the primary path.

We kept the Nx.Vulkan.Fast module. The architectural decoupling is sound — named kernels at the call site, fallbacks that work cross-backend, optional dispatch via Nx's own primitive. For workloads where each fused kernel does substantial GPU work, the named-kernel path is correct. We just don't have those workloads. The module ships as an explicit-opt-in API for users who do.

We deleted the eXMC leapfrog refactor. The diagonal-mass step went back to the original direct Nx-op path. The blog post that we had already written, declaring victory on the architectural shift, disappeared into the editor and re-emerged as this one.

The lesson, revised

The Emily pattern is not a universal upgrade over IR-walking compilers. It is a different point on the cost curve. Named kernels amortise their dispatch overhead against fat per-kernel work; the break-even depends on how fat. For elementwise chains on tensors of eight or sixty-four floats, the break-even is unfavourable, and the unmodified Nx primitive dispatch path beats the named-kernel path by a factor of nine.

The cost of not measuring before refactoring is the cost of refactoring. We had read Emily's source carefully. We had drawn the architectural analogy. We had not run the comparison benchmark. The benchmark was three lines — same fn, same args, two compiler choices, one ratio. Three lines would have saved a session.
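
A generic shape for that comparison, assuming nothing beyond :timer.tc — the two placeholder workloads here stand in for the leapfrog body traced under the two compilers:

```elixir
defmodule MicroBench do
  # Time two implementations of the same work; report per-call cost
  # and the slowdown of the second relative to the first.
  def compare(fun_a, fun_b, reps \\ 1_000) do
    {us_a, _} = :timer.tc(fn -> Enum.each(1..reps, fn _ -> fun_a.() end) end)
    {us_b, _} = :timer.tc(fn -> Enum.each(1..reps, fn _ -> fun_b.() end) end)
    %{a_us: us_a / reps, b_us: us_b / reps, ratio: us_b / us_a}
  end
end

# Placeholder workloads; in the real fixture these would be the same
# defn compiled under Nx.Vulkan.Compiler and under Nx.Defn.Evaluator.
result = MicroBench.compare(
  fn -> Enum.sum(1..100) end,
  fn -> Enum.sum(1..10_000) end
)
IO.inspect(result.ratio)  # > 1 when the second path is slower
```

Same fn, same args, two paths, one ratio — the whole fixture fits on a screen, which is the point.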

The next time we are tempted by an architectural shift on the basis of a different project's design choices, we will run the microbenchmark first. The benchmark is always cheaper than the refactor. The benchmark is always cheaper than the blog post you have to rewrite.

What stays

The IR walker. The four-input fused shader. The auto-detect for three- and four-argument defn bodies. Nx.Vulkan.Compiler remains the primary path for typical defn-traced code. Nx.Vulkan.Fast remains an explicit opt-in, kept honest by its fallbacks. The benchmark stays in the repository as a test fixture, where the next person to consider this question can run it and read the answer in one second.


The Nx.Vulkan.Fast module is at lib/nx_vulkan/fast.ex. The matching backend callbacks are in lib/nx_vulkan/backend.ex. The Emily source that prompted the refactor is at elixir-nx/emily/lib/emily/fast.ex. The four-input fused shader the kernels dispatch under the hood is at spirit/shaders/fused_elementwise_4in.comp, written by mac-248 across three increments in late April 2026.