Clean-room Rust reimplementation of JAX's transform semantics.
Semantic fidelity. Mathematical rigor. Operational safety. Profile-proven performance.
The problem: JAX's transform semantics (jit, grad, vmap) are deeply entangled with Python and XLA. There's no standalone, portable, verifiable implementation of the mathematical core.
The solution: FrankenJAX extracts and reimplements JAX's transform composition model in Rust with a canonical JAXPR-like IR, full automatic differentiation, and a differential conformance harness that validates every primitive against the real JAX oracle.
Why FrankenJAX?
| Feature | Status |
|---|---|
| 110 primitive operations with full eval | All green |
| Reverse-mode (VJP) + Forward-mode (JVP) AD for all 110 primitives | All green |
| Transform composition: jit(grad(f)), vmap(grad(f)), grad(grad(f)) | All green |
| Linear algebra: Cholesky, QR, SVD, Eigh, TriangularSolve (eval + AD) | All green |
| FFT: Fft, Ifft, Rfft, Irfft (eval + AD) | All green |
| E-graph equality saturation optimizer (87 algebraic rewrite rules) | All green |
| 834 JAX oracle fixture cases for differential conformance | All green |
| RaptorQ erasure-coded durability for all long-lived artifacts | All green |
| Strict/Hardened compatibility-security mode split | All green |
| 1,867 #[test] cases + proptest suites | All passing |
| | FrankenJAX | JAX (Google) | PyTorch | Enzyme (LLVM) | Autograd |
|---|---|---|---|---|---|
| Language | Rust | Python/C++ | Python/C++ | LLVM IR | Python |
| Runtime dependency | None (standalone) | Python + XLA + CUDA | Python + CUDA | LLVM toolchain | NumPy |
| Transform composition | Full (jit/grad/vmap) | Full | Limited (torch.func) | grad only | grad only |
| Verifiable proofs | Trace Transform Ledger | No | No | No | No |
| Oracle conformance | 834 JAX-verified cases | N/A (is the oracle) | No | No | No |
| Artifact durability | RaptorQ sidecars | No | No | No | No |
| E-graph optimization | 87 rules, equality saturation | XLA HLO passes | TorchScript/Inductor | LLVM passes | None |
| Embeddable | Yes (Rust library + C FFI) | No (Python required) | Partially (libtorch) | Yes (LLVM plugin) | No |
FrankenJAX is not a replacement for JAX in production ML training. It is a reference implementation of JAX's mathematical transform semantics that you can embed in Rust applications, use as a verification oracle, or study to understand how composable transforms work.
- Compiler researchers studying how jit/grad/vmap compose and interact
- Rust developers who need automatic differentiation without Python
- Verification engineers who want auditable transform composition with proof artifacts
- ML framework developers who need a reference implementation of JAX semantics
- Educators teaching automatic differentiation, functional transforms, or IR design
use fj_api::{jit, grad, vmap, jacobian, hessian, make_jaxpr};
// Trace a Rust closure into canonical IR
let jaxpr = make_jaxpr(|x| x * x + 3.0 * x)?;
// Apply transforms exactly like JAX
let result = jit(jaxpr.clone()).call(vec![Value::scalar_f64(5.0)])?;
let gradient = grad(jaxpr.clone()).call(vec![Value::scalar_f64(5.0)])?;
// Compose transforms
let batched_grad = jit(jaxpr.clone()).compose_grad();
let J = jacobian(jaxpr.clone()).call(vec![Value::scalar_f64(5.0)])?;
let H = hessian(jaxpr).call(vec![Value::scalar_f64(5.0)])?;
What happens internally when you compute grad(f)(3.0) for f(x) = x^2 + 3x:
Step 1: Trace. make_jaxpr runs the closure with a tracer value, recording operations:
v1 = input
v2 = mul(v1, v1) # x^2
v3 = mul(3.0, v1) # 3x (3.0 stored as Atom::Lit)
v4 = add(v2, v3) # x^2 + 3x
output = v4
Step 2: Dispatch. The grad transform tells the dispatcher to apply reverse-mode AD.
Step 3: Forward pass. Evaluate with x=3.0, recording the tape:
v1 = 3.0
v2 = mul(3.0, 3.0) = 9.0 tape: [mul, inputs=[3.0, 3.0], output=9.0]
v3 = mul(3.0, 3.0) = 9.0 tape: [mul, inputs=[3.0, 3.0], output=9.0]
v4 = add(9.0, 9.0) = 18.0 tape: [add, inputs=[9.0, 9.0], output=18.0]
Step 4: Backward pass. Walk the tape in reverse with output cotangent = 1.0:
g_v4 = 1.0 (seed)
add_vjp: g_v2 += 1.0, g_v3 += 1.0 (add gradient = pass-through)
mul_vjp for v3: g_v1 += 3.0 * 1.0 = 3.0 (d(3x)/dx = 3)
mul_vjp for v2: g_v1 += 3.0 * 1.0 + 3.0 * 1.0 = 6.0 (d(x^2)/dx = 2x = 6)
g_v1 total = 3.0 + 6.0 = 9.0
Result: gradient = 9.0, which matches d/dx(x^2 + 3x)|_{x=3} = 2(3) + 3 = 9. Confirmed against JAX oracle.
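The four steps above can be reproduced with a very small tape. The sketch below is illustrative only: `Op` and `value_and_grad` are hypothetical names, the program is hard-coded to f(x) = x^2 + 3x, and nothing here is the fj-ad API.

```rust
/// One recorded operation; operands are indices into the value table.
/// (Hypothetical type for illustration, not fj-ad's TapeEntry.)
#[derive(Clone, Copy)]
enum Op {
    Mul(usize, usize),
    Add(usize, usize),
}

/// Evaluate f(x) = x^2 + 3x while recording a tape, then walk it in reverse.
/// Returns (f(x), df/dx).
fn value_and_grad(x: f64) -> (f64, f64) {
    // Slot 0 holds the input x, slot 1 holds the literal 3.0.
    // The i-th tape entry writes its result into slot i + 2.
    let mut vals = vec![x, 3.0];
    let tape = [Op::Mul(0, 0), Op::Mul(1, 0), Op::Add(2, 3)];

    // Forward pass.
    for op in &tape {
        let out = match *op {
            Op::Mul(a, b) => vals[a] * vals[b],
            Op::Add(a, b) => vals[a] + vals[b],
        };
        vals.push(out);
    }

    // Backward pass: seed the output cotangent with 1.0, then accumulate.
    let mut grads = vec![0.0; vals.len()];
    *grads.last_mut().unwrap() = 1.0;
    for (i, op) in tape.iter().enumerate().rev() {
        let g = grads[i + 2];
        match *op {
            Op::Mul(a, b) => {
                grads[a] += g * vals[b];
                grads[b] += g * vals[a];
            }
            Op::Add(a, b) => {
                grads[a] += g;
                grads[b] += g;
            }
        }
    }
    (vals[vals.len() - 1], grads[0])
}
```

Calling `value_and_grad(3.0)` yields the value 18.0 and the gradient 9.0 from the walkthrough. The literal also receives a cotangent, which the sketch simply ignores.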
1. Transform composition semantics are non-negotiable. Every transform composition produces a verifiable proof artifact linking input IR, applied transforms, and output IR via the Trace Transform Ledger (TTL).
2. Differential conformance, not reimplementation faith. Every primitive's behavior is validated against real JAX 0.9.2 oracle fixtures. We verify, not trust, our implementations: 834 oracle test cases spanning transforms, AD, linalg, FFT, RNG, dtype promotion, and transform composition.
3. Strict/Hardened mode split. Strict mode maximizes observable compatibility. Hardened mode adds safety guards with bounded defensive recovery. You choose the tradeoff per invocation.
4. RaptorQ-everywhere durability. Long-lived artifacts (fixtures, baselines, ledgers) get erasure-coded sidecars with scrub reports and decode proofs. Bit rot does not get to silently corrupt your conformance evidence.
5. Correctness is measured, not assumed. Numerical AD rules are validated against finite-difference gradients. The Cholesky VJP/JVP bugs we found and fixed? Found by numerical verification tests, not by staring at formulas.
User API (fj-api)
|
v
Trace (fj-trace) --> Canonical IR (fj-core: Jaxpr + 110 Primitives)
|
v
Transform Stack (fj-dispatch)
| jit grad vmap
| \ | /
v v v v
+-- E-graph optimizer (fj-egraph: 87 rewrite rules)
+-- AD engine (fj-ad: VJP + JVP for all 110 primitives)
+-- Batch trace (fj-dispatch/batching: per-primitive vmap rules)
+-- Evidence ledger (fj-ledger: transform composition proofs)
|
v
Lowering + Eval (fj-lax: arithmetic, linalg, FFT, tensor ops)
|
v
CPU Backend (fj-backend-cpu: dependency-wave parallel executor)
|
v
Cache (fj-cache: SHA-256 deterministic keys, strict/hardened gates)
| Crate | Purpose | Tests |
|---|---|---|
| fj-core | Canonical IR (Jaxpr), 110 primitives, 11 dtypes, shapes, value model | Extensive |
| fj-lax | Primitive evaluation: arithmetic, linalg, FFT, reductions, tensor ops | 479 |
| fj-ad | Automatic differentiation: VJP + JVP for all 110 primitives | 179 |
| fj-dispatch | Transform dispatch, order-sensitive composition, batching | 55+ |
| fj-trace | make_jaxpr tracing from Rust closures, nested trace contexts | 50 |
| fj-egraph | E-graph equality saturation: 87 algebraic rewrite rules | 47 |
| fj-api | User-facing API: jit, grad, vmap, jacobian, hessian | 38 |
| fj-cache | Deterministic cache keys, strict/hardened gate behavior | Yes |
| fj-ledger | Decision/evidence ledger, loss-matrix actions, audit trail | Yes |
| fj-runtime | Tensor-aware runtime value model, optional async integration | Yes |
| fj-interpreters | Scoped primitive interpreter, partial evaluation, staging | Yes |
| fj-conformance | Differential conformance harness, 834 oracle fixtures, durability | 200+ |
| fj-backend-cpu | Dependency-wave parallel executor (rayon) | 40 |
| fj-ffi | C FFI surface (only crate permitted unsafe) | Yes |
| fj-test-utils | Shared test scaffolding, fixture helpers | Yes |
80,872 lines of Rust across 15 crates with end-to-end trace -> dispatch -> runtime pipeline:
- 110 primitive operations covering arithmetic, trigonometric, hyperbolic, comparison, reduction, shape manipulation, linear algebra, FFT, bitwise, control flow, sorting, convolution, and special math functions
- 11 DTypes (BF16, F16, F32, F64, I32, I64, U32, U64, Bool, Complex64, Complex128) with JAX-verified type promotion rules
- Full AD coverage: all 110 primitives have both VJP (reverse-mode) and JVP (forward-mode) rules, including multi-output decompositions (Cholesky, QR, SVD, Eigh) and FFT
- Jacobian and Hessian matrix computation via composable AD
- vmap with per-primitive batching rules, in_axes/out_axes, BatchTrace interpreter
- E-graph optimizer: 87 algebraic rewrite rules with equality saturation, verified to preserve program semantics
- ThreeFry2x32 RNG: key/split/fold_in/uniform/normal/bernoulli/categorical with JAX-matched determinism
- Control flow: cond, scan, while_loop, fori_loop, switch with AD support
- 834 JAX oracle fixture cases captured from JAX 0.9.2 with x64 mode, covering transforms, AD, linalg, FFT, RNG, dtype promotion, and transform composition
- 1,867 #[test] cases plus proptest/property-based suites
- RaptorQ durability pipeline for all long-lived evidence artifacts
At the center of FrankenJAX is the Jaxpr (JAX expression), a functional intermediate representation:
struct Jaxpr {
invars: Vec<VarId>, // Input variables (function parameters)
constvars: Vec<VarId>, // Constant bindings (captured values)
outvars: Vec<VarId>, // Output variables (return values)
equations: Vec<Equation>, // Sequence of primitive operations
}
struct Equation {
primitive: Primitive, // Which of the 110 operations
inputs: SmallVec<[Atom; 4]>, // Input references (variables or literals)
outputs: SmallVec<[VarId; 2]>, // Output variable bindings
params: BTreeMap<String, String>, // Operation parameters (axes, dtypes, etc.)
effects: Vec<...>, // Side-effect tokens
sub_jaxprs: Vec<Jaxpr>, // Nested Jaxprs (for control flow)
}
enum Atom {
Var(VarId), // Reference to a variable in the environment
Lit(Literal), // Inline constant value
}
A Jaxpr is a straight-line program: no branches, no loops at the top level. Control flow (cond, scan, while_loop, switch) is expressed via primitives that take sub-Jaxprs as arguments. This design makes transforms (grad, vmap) straightforward: they operate equation-by-equation, and each primitive has well-defined transform rules.
Tracing: The make_jaxpr function traces a Rust closure by running it with abstract tracer values that record operations instead of computing them. The result is a Jaxpr that represents the computation graph.
FrankenJAX represents runtime values with a two-level model:
enum Value {
Scalar(Literal), // Single element
Tensor(TensorValue), // Multi-element with shape
}
enum Literal {
Bool(bool),
I64(i64), // Also used for I32 (widened internally)
U32(u32),
U64(u64),
BF16Bits(u16), // Stored as raw bits
F16Bits(u16), // Stored as raw bits
F64Bits(u64), // All floats stored as f64 bits
Complex64Bits(u32, u32), // (re_bits, im_bits)
Complex128Bits(u64, u64), // (re_bits, im_bits)
}
struct TensorValue {
dtype: DType, // Element type tag
shape: Shape, // Dimension vector: Vec<u32>
elements: Vec<Literal>, // Flat row-major storage
}
Design rationale: Storing floats as bit patterns (F64Bits(u64)) rather than as f64 directly ensures exact round-trip serialization/deserialization without any NaN canonicalization surprises. The Literal enum is Copy, enabling cheap value passing through the interpreter. TensorValue uses flat row-major storage with a shape vector, matching NumPy/JAX's memory layout.
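To make the bit-pattern rationale concrete, here is a minimal sketch with a hypothetical one-variant `Lit` enum (not the real `Literal`). It assumes only the standard `f64::to_bits` and `f64::from_bits` conversions, and shows that a bit-backed literal can derive `Eq` and `Hash` and can tell apart values that compare equal (or never equal) as floats.

```rust
/// Hypothetical stand-in for a float literal stored as raw IEEE 754 bits.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Lit {
    F64Bits(u64),
}

impl Lit {
    /// Capture the exact bit pattern of `x`.
    fn from_f64(x: f64) -> Self {
        Lit::F64Bits(x.to_bits())
    }

    /// Reinterpret the stored bits as an f64.
    fn as_f64(self) -> f64 {
        match self {
            Lit::F64Bits(bits) => f64::from_bits(bits),
        }
    }
}
```

With this representation `Lit::from_f64(0.0)` and `Lit::from_f64(-0.0)` are distinct values even though `0.0 == -0.0` as floats, and a NaN literal compares equal to itself because the comparison is on bits.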
FrankenJAX supports Complex64 (32-bit real + 32-bit imag) and Complex128 (64-bit real + 64-bit imag) throughout the stack:
Arithmetic: Complex binary operations use proper field arithmetic:
- Addition: (a+bi) + (c+di) = (a+c) + (b+d)i
- Multiplication: (a+bi)(c+di) = (ac-bd) + (ad+bc)i
- Division: (a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c^2+d^2)
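The three formulas above translate directly into code. This is a minimal sketch using a hypothetical `Complex` struct over f64 components; it does not reflect how fj-lax stores Complex64/Complex128 values, and it makes no attempt at overflow-safe division.

```rust
/// Hypothetical complex number type for illustration.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Complex {
    re: f64,
    im: f64,
}

/// (a+bi) + (c+di) = (a+c) + (b+d)i
fn c_add(x: Complex, y: Complex) -> Complex {
    Complex { re: x.re + y.re, im: x.im + y.im }
}

/// (a+bi)(c+di) = (ac-bd) + (ad+bc)i
fn c_mul(x: Complex, y: Complex) -> Complex {
    Complex {
        re: x.re * y.re - x.im * y.im,
        im: x.re * y.im + x.im * y.re,
    }
}

/// (a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c^2+d^2)
fn c_div(x: Complex, y: Complex) -> Complex {
    let denom = y.re * y.re + y.im * y.im;
    Complex {
        re: (x.re * y.re + x.im * y.im) / denom,
        im: (x.im * y.re - x.re * y.im) / denom,
    }
}
```

For example, (1+2i)(3+4i) evaluates to -5+10i, and (1+2i)/(3+4i) evaluates to 0.44+0.08i up to rounding.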
Primitives: Complex(re, im) constructs a complex value, Real(z) and Imag(z) extract components, Conj(z) conjugates. All trigonometric and exponential functions support complex arguments.
FFT: Complex-valued FFT/IFFT operates on Complex128 tensors directly. RFFT takes real input and produces Complex128 output (half-spectrum). IRFFT inverts this.
AD: VJP and JVP rules for complex operations follow the Wirtinger derivative convention where applicable. The e-graph optimizer has dedicated rules for complex simplification: real(complex(r, i)) → r, imag(complex(r, i)) → i, complex(real(z), imag(z)) → z.
FrankenJAX implements JAX's type promotion lattice for mixed-type operations. When you add an i64 and an f64, the result type follows this hierarchy:
Bool --> I32 --> I64 --------+
                             +--> F64
U32 --> U64 -----------------+
BF16 --> F32 --> F64
F16 --> F32 --> F64
Complex64 --> Complex128
Key promotion rules verified against JAX oracle:
- Integer + Float → Float (e.g., i64 + f64 → f64)
- Unsigned + Signed → wider type or Float (e.g., u64 + i64 → f64)
- Any + Complex → Complex
- Same-type operations preserve type (e.g., u32 + u32 → u32)
The full 9x9 promotion matrix (162 cases for add and multiply) is captured from JAX 0.9.2 and validated in CI.
FrankenJAX implements JAX's functional control flow primitives, which express branching and iteration within the Jaxpr IR:
| Primitive | Semantics | Sub-Jaxprs |
|---|---|---|
| cond | if pred then true_branch(args) else false_branch(args) | 2 (true, false) |
| scan | Fold/scan with carry: (carry, ys) = scan(body, init_carry, xs) | 1 (body) |
| while_loop | while cond(state): state = body(state) | 2 (cond, body) |
| fori_loop | for i in range(lo, hi): state = body(i, state) | 1 (body) |
| switch | Multi-way branch: branches[index](args) | N (one per branch) |
These primitives compose with AD. grad(scan(f)) differentiates through the scan by unrolling the tape. vmap(cond(...)) batches the condition predicate and both branches.
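As a reading aid for the semantics column, the sketch below expresses `scan` and `fori_loop` as ordinary Rust functions over closures. The function names mirror the primitives, but the signatures are assumptions made for this example; the real primitives take sub-Jaxprs, not closures.

```rust
/// (carry, ys) = scan(body, init_carry, xs): thread a carry through `xs`
/// and collect one output per step.
fn scan<C, X, Y>(
    mut body: impl FnMut(C, &X) -> (C, Y),
    init_carry: C,
    xs: &[X],
) -> (C, Vec<Y>) {
    let mut carry = init_carry;
    let mut ys = Vec::with_capacity(xs.len());
    for x in xs {
        let (next, y) = body(carry, x);
        carry = next;
        ys.push(y);
    }
    (carry, ys)
}

/// for i in range(lo, hi): state = body(i, state)
fn fori_loop<S>(lo: i64, hi: i64, mut body: impl FnMut(i64, S) -> S, init: S) -> S {
    let mut state = init;
    for i in lo..hi {
        state = body(i, state);
    }
    state
}
```

A cumulative sum is the classic use: the carry is the running total and each step also emits it as an output.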
The fj-ledger crate implements a decision-theoretic audit trail for runtime choices:
Every non-trivial decision (cache hit vs recompute, strict vs hardened recovery, optimization level) is logged as a DecisionRecord with:
- Action taken and alternatives considered
- Loss-matrix justification: expected cost of each alternative under the current mode
- Evidence signals: what information was available at decision time
- Conformal predictor: statistical confidence bound on the decision quality
This is a formal audit trail, not a debugging log. It answers "why did the system do X instead of Y?" for any execution. The ledger entries survive across sessions via the durability pipeline.
FrankenJAX uses four layers of correctness assurance, each catching different classes of bugs:
Layer 1: Oracle conformance (834 cases)
Capture expected outputs from real JAX, replay against FrankenJAX. Catches: wrong evaluation logic, incorrect primitive semantics, dtype mismatches.
Layer 2: Numerical AD verification (9 tests)
Compare analytical gradients (VJP/JVP rules) against finite-difference approximations. Catches: wrong derivative formulas, sign errors, missing terms. Found two real Cholesky bugs.
Layer 3: Property-based testing (proptest)
Generate random inputs and verify algebraic invariants: exp(log(x)) ≈ x, Q^T Q ≈ I for QR, U Σ V^T ≈ A for SVD. Catches: edge cases, numerical instability, overflow.
Layer 4: E-graph semantics preservation (12 tests)
Run programs with and without optimization, verify identical results. Catches: rewrite rules that change program meaning, cost model bugs that prefer wrong programs.
Combined, these layers provide defense-in-depth: oracle tests catch "wrong answer" bugs, numerical tests catch "wrong gradient" bugs, property tests catch "crashes on weird input" bugs, and e-graph tests catch "optimization broke something" bugs.
FrankenJAX uses tape-based reverse-mode AD (backpropagation) with full forward-mode support:
Forward pass with tape recording:
The AD engine evaluates the Jaxpr equation-by-equation, recording each operation as a TapeEntry that captures the primitive, input/output values, and parameters. Multi-output primitives (QR, SVD, Eigh) store all primal outputs explicitly because their VJP rules need them for the matrix calculus.
Backward pass with cotangent threading:
The tape is replayed in reverse. For each entry, the primitive's VJP rule computes input cotangents from output cotangents. When a variable appears in multiple equations, its gradients are accumulated (summed). An early-termination check skips entries where all output gradients are zero.
Forward: x --[mul]--> t1 --[add]--> y (record tape)
Backward: g_y --[add_vjp]--> g_t1 --[mul_vjp]--> g_x (replay reverse)
Each of the 110 primitives has hand-derived VJP and JVP rules. Here are a few concrete examples showing the mathematical formulas implemented:
| Primitive | VJP Rule (given output cotangent ḡ) | JVP Rule (given input tangent ẋ) |
|---|---|---|
| mul(a, b) | ḡ_a = ḡ * b, ḡ_b = ḡ * a | ẏ = ẋ_a * b + a * ẋ_b (product rule) |
| sin(x) | ḡ_x = ḡ * cos(x) | ẏ = ẋ * cos(x) |
| exp(x) | ḡ_x = ḡ * exp(x) | ẏ = ẋ * exp(x) |
| reduce_sum(x) | ḡ_x = broadcast(ḡ, x.shape) | ẏ = reduce_sum(ẋ) |
| reshape(x) | ḡ_x = reshape(ḡ, x.shape) | ẏ = reshape(ẋ, new_shape) |
| transpose(x) | ḡ_x = transpose(ḡ, inv_perm) | ẏ = transpose(ẋ, perm) |
Non-differentiable primitives (comparisons, bitwise ops, floor/ceil) return zero gradients, which is mathematically correct since their derivatives are zero almost everywhere.
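The JVP column of the table can be exercised with dual numbers, where each value carries its tangent alongside its primal. The sketch below is a scalar-only illustration; `Dual`, `mul`, `sin`, `exp`, and `jvp` are hypothetical helpers, not fj-ad functions.

```rust
/// A primal value paired with its tangent (forward-mode AD).
#[derive(Clone, Copy, Debug)]
struct Dual {
    primal: f64,
    tangent: f64,
}

/// Product rule: tangent = a_dot * b + a * b_dot
fn mul(a: Dual, b: Dual) -> Dual {
    Dual {
        primal: a.primal * b.primal,
        tangent: a.tangent * b.primal + a.primal * b.tangent,
    }
}

/// tangent = x_dot * cos(x)
fn sin(x: Dual) -> Dual {
    Dual { primal: x.primal.sin(), tangent: x.tangent * x.primal.cos() }
}

/// tangent = x_dot * exp(x)
fn exp(x: Dual) -> Dual {
    let e = x.primal.exp();
    Dual { primal: e, tangent: x.tangent * e }
}

/// Push the tangent `x_dot` through `f` at the point `x`.
/// Returns (f(x), directional derivative).
fn jvp(f: impl Fn(Dual) -> Dual, x: f64, x_dot: f64) -> (f64, f64) {
    let out = f(Dual { primal: x, tangent: x_dot });
    (out.primal, out.tangent)
}
```

For instance, `jvp(|x| mul(x, sin(x)), 2.0, 1.0)` returns the value 2·sin(2) and the derivative sin(2) + 2·cos(2).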
The linalg decomposition VJPs follow Murray 2016 ("Differentiation of the Cholesky decomposition") and related literature, with diagonal-halving corrections for the Phi operator and careful triangular-solve direction for L^{-T} vs L^{-1} terms.
The optimizer converts Jaxpr programs into an e-graph (equivalence graph) and applies equality saturation via the egg library. Instead of choosing one rewrite direction, all applicable rewrites fire simultaneously, and a cost function extracts the cheapest equivalent program.
87 algebraic rewrite rules, including:
| Category | Example Rule | Effect |
|---|---|---|
| Arithmetic identities | x + 0 → x, x * 1 → x | Eliminate no-ops |
| Strength reduction | x + x → 2 * x | Reduce operation count |
| Inverse pairs | exp(log(x)) → x, neg(neg(x)) → x | Cancel inverses |
| Trig identities | sin²(x) + cos²(x) → 1 | Simplify trig expressions |
| Distributivity | a*(b+c) ↔ a*b + a*c | Factor or expand as needed |
| Complex field | real(complex(r, i)) → r | Simplify complex extraction |
| Comparison absorption | max(a, min(a, b)) → a | Eliminate nested comparisons |
The cost model (OpCount) counts the number of operations in each equivalent expression. Because the original program is itself a member of the e-graph, the extracted program never has more operations than the original.
Transform composition is order-sensitive. grad(vmap(f)) produces different results from vmap(grad(f)). The dispatcher processes a stack of transforms against a Jaxpr:
- Strip leading jit (compile-time annotation, no-op in V1)
- Apply innermost transform first: grad uses symbolic tape-based AD; vmap uses the BatchTrace interpreter with per-primitive batching rules
- Compose recursively: for grad(vmap(f)), first vmap the Jaxpr, then grad the vectorized result
- Thread effect tokens through execution to maintain side-effect ordering
- Record composition proof in the Trace Transform Ledger for auditability
The dispatcher also supports e-graph optimization as a compile option. When egraph_optimize=true, the Jaxpr is optimized via equality saturation before evaluation.
All linalg primitives are implemented in pure Rust with f64 arithmetic:
| Decomposition | Algorithm | Key Properties |
|---|---|---|
| Cholesky | Cholesky-Banachiewicz (row-by-row) | L where A = LL^T; requires SPD input |
| QR | Householder reflections | Q (orthogonal), R (upper triangular); sign-normalized diagonal |
| SVD | Jacobi rotations via A^T A eigendecomposition | U, S (descending), V^T; up to 100n^2 iterations |
| Eigh | Jacobi eigendecomposition | W (ascending eigenvalues), V (orthonormal eigenvectors) |
| TriangularSolve | Forward/back substitution | Exact for triangular systems; supports lower and upper |
The Cholesky VJP uses the formula bar_A = L^{-T} Phi(L^T bar_L) L^{-1} where Phi extracts the lower triangle with halved diagonal (Murray 2016). The JVP is dL = L Phi(L^{-1} dA L^{-T}). Both were numerically verified against finite-difference gradients.
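For reference, the row-by-row Cholesky-Banachiewicz recurrence named in the table can be sketched in a few lines. This is an illustrative version over `Vec<Vec<f64>>` with a hypothetical `cholesky` function; it covers only the primal factorization, not the VJP/JVP formulas, and it assumes the input is square and symmetric.

```rust
/// Compute lower-triangular L with A = L L^T, filling L one row at a time.
/// Returns None when a pivot is not strictly positive (A is not SPD).
fn cholesky(a: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let n = a.len();
    let mut l = vec![vec![0.0; n]; n];
    for i in 0..n {
        for j in 0..=i {
            // Dot product of the already-computed parts of rows i and j.
            let dot: f64 = (0..j).map(|k| l[i][k] * l[j][k]).sum();
            if i == j {
                let pivot = a[i][i] - dot;
                if !(pivot > 0.0) {
                    return None; // also rejects NaN pivots
                }
                l[i][j] = pivot.sqrt();
            } else {
                l[i][j] = (a[i][j] - dot) / l[j][j];
            }
        }
    }
    Some(l)
}
```

For A = [[4, 2], [2, 3]] this produces L = [[2, 0], [1, √2]], while the indefinite matrix [[1, 2], [2, 1]] is rejected.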
The RNG implements the ThreeFry2x32 counter-based PRNG from Salmon et al. (SC'11: "Parallel Random Numbers: As Easy as 1, 2, 3"):
- Core cipher: 20 rounds of rotation + XOR + key injection on 2-word (64-bit) state, using Skein rotation constants [13, 15, 26, 6, 17, 29, 16, 24]
- Key splitting: split(key) = [threefry(key, [0,0]), threefry(key, [0,1])], producing two statistically independent child keys
- Fold-in: fold_in(key, data) = threefry(key, [data, 0]), mixing external data into the key
- Sampling: counter-based bit generation, then uniform (divide by 2^32), normal (Box-Muller), bernoulli (threshold), or categorical (Gumbel-max trick)
The deterministic design means random_key(42) always produces the same sequence, matching JAX's ThreeFry implementation. This is verified against 25 JAX oracle fixtures covering key generation, splitting, fold-in, uniform, and normal distributions.
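The core cipher is compact enough to sketch. The version below follows the 20-round ThreeFry2x32 structure as formulated in the Random123 reference and in JAX (five groups of four add/rotate/XOR rounds, each followed by a key injection). The function name and array-based signature are assumptions for this example, and the higher-level key/split/fold_in helpers are intentionally left out.

```rust
/// Skein rotation constants, alternating between the two groups.
const ROTATIONS: [[u32; 4]; 2] = [[13, 15, 26, 6], [17, 29, 16, 24]];

/// 20-round ThreeFry2x32 block function (illustrative sketch).
fn threefry2x32(key: [u32; 2], counter: [u32; 2]) -> [u32; 2] {
    // Extended key schedule: third word is the parity of the key words.
    let ks = [key[0], key[1], key[0] ^ key[1] ^ 0x1BD1_1BDA];

    // Initial key injection.
    let mut x = [
        counter[0].wrapping_add(ks[0]),
        counter[1].wrapping_add(ks[1]),
    ];

    for group in 0..5u32 {
        // Four mixing rounds: add, rotate, XOR.
        for &r in &ROTATIONS[(group % 2) as usize] {
            x[0] = x[0].wrapping_add(x[1]);
            x[1] = x[1].rotate_left(r) ^ x[0];
        }
        // Key injection with a round counter added to the second word.
        let g = group as usize;
        x[0] = x[0].wrapping_add(ks[(g + 1) % 3]);
        x[1] = x[1]
            .wrapping_add(ks[(g + 2) % 3])
            .wrapping_add(group + 1);
    }
    x
}
```

Because the function is a pure mapping from (key, counter) to two output words, calling it twice with the same arguments always returns the same bits, which is what makes fixture-based determinism checks possible.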
The CPU backend parallelizes Jaxpr execution via dependency-wave scheduling:
Wave 1: a = f(x) b = g(x) <-- parallel (both depend only on input)
Wave 2: c = h(a, b) <-- sequential (depends on wave 1)
Wave 3: d = k(c) <-- sequential
The algorithm:
- Find all equations whose inputs are available (the "ready wave")
- Execute the wave in parallel via Rayon's thread pool
- Store outputs in the environment
- Repeat until all equations are executed
Barrier detection: equations with side effects, sub-Jaxprs, or multi-output primitives force sequential execution. This prevents reordering of effectful operations while maximizing parallelism for pure computations.
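The wave-finding loop described above can be sketched without any threading. In this illustration `deps[i]` lists the equations whose outputs equation `i` consumes; `schedule_waves` is a hypothetical helper that only computes the grouping, so the Rayon execution step and barrier detection are omitted.

```rust
/// Group equation indices into waves such that every equation's
/// dependencies were executed in an earlier wave.
/// Returns None if the dependencies contain a cycle.
fn schedule_waves(deps: &[Vec<usize>]) -> Option<Vec<Vec<usize>>> {
    let n = deps.len();
    let mut done = vec![false; n];
    let mut waves = Vec::new();
    let mut remaining = n;

    while remaining > 0 {
        // The ready wave: not yet executed and all inputs available.
        let wave: Vec<usize> = (0..n)
            .filter(|&i| !done[i] && deps[i].iter().all(|&d| done[d]))
            .collect();
        if wave.is_empty() {
            return None; // nothing is ready, so there must be a cycle
        }
        // A parallel executor would run the whole wave concurrently here.
        for &i in &wave {
            done[i] = true;
        }
        remaining -= wave.len();
        waves.push(wave);
    }
    Some(waves)
}
```

For the three-wave example above (a and b independent, c depending on both, d depending on c) the result is `[[0, 1], [2], [3]]`.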
Every compilation/execution configuration gets a deterministic SHA-256 cache key:
fjx-<sha256hex> = SHA-256(
mode | backend | transforms | compile_options | custom_hook | jaxpr_fingerprint
)
The Jaxpr fingerprint recursively hashes the equation structure (primitives, arities, parameters, sub-Jaxprs). Transform ordering matters: grad,vmap and vmap,grad produce different keys. Compile options are sorted (BTreeMap) for deterministic ordering.
Strict mode rejects cache entries with unknown incompatible features. Hardened mode allows bounded recovery from unexpected cache states.
Long-lived artifacts (conformance fixtures, benchmark baselines, evidence ledgers) are protected against bit rot using RaptorQ erasure coding:
- Encode: split artifact into source symbols (256-byte chunks), generate 10% repair symbols
- Sidecar: JSON manifest with all symbols (Base64-encoded), SHA-256 hash, generation metadata
- Scrub: decode from sidecar symbols, verify SHA-256 match against original artifact
- Decode proof: intentionally drop N source symbols, verify recovery from remaining symbols + repair symbols
Any fixture bundle can therefore survive partial data loss and still be recovered, which matters in distributed CI environments where cache nodes can fail silently.
Five linear-algebra primitives are routed through the multi-output path: QR (Q, R), SVD (U, S, Vt), and Eigh (W, V) return several outputs, while Cholesky (L) and TriangularSolve (X) return a single output but share the same path. These require special handling at every layer:
| Layer | Single-output | Multi-output |
|---|---|---|
| IR | outputs: [VarId(2)] | outputs: [VarId(2), VarId(3)] or [VarId(2), VarId(3), VarId(4)] |
| Eval | eval_primitive() -> Value | eval_primitive_multi() -> Vec<Value> |
| VJP | vjp(prim, inputs, [g]) -> [g_input] | vjp(prim, inputs, [g1, g2], [out1, out2]) -> [g_input] |
| JVP | jvp_rule(prim, primals, tangents) -> tangent | jvp_rule_multi(prim, primals, tangents, primal_outs) -> [t1, t2] |
| Vmap | Batch dim tracked on single output | Each output independently tracks its batch dim |
| Tape | Stores one output value | Stores all output values (needed for VJP of decompositions) |
The SVD VJP, for example, needs the primal U, S, and Vt outputs to compute dA from dU, dS, and dVt. This is why the tape records output_values; they're consumed during the backward pass.
FrankenJAX reimplements JAX's semantics, not its implementation. Key architectural differences:
| Aspect | JAX (Python/C++) | FrankenJAX (Rust) |
|---|---|---|
| IR | JAXPR with Python objects | Jaxpr with Rust enums (fully serializable) |
| Tracing | Python abstract interpreter | Rust closure tracing via TracerRef operator overloading |
| AD tape | Implicit via Python closures | Explicit Vec<TapeEntry> with value snapshots |
| Compilation | XLA HLO to device code | E-graph optimization then direct interpretation |
| Dispatch | C++ xla::Client | Rust dispatch() with transform stack |
| Type system | NumPy dtype objects | Rust DType enum (11 variants) |
| Parallelism | XLA partitioning + SPMD | Rayon dependency-wave scheduling |
| Value model | NumPy arrays + JAX DeviceArray | Value::Scalar(Literal) / Value::Tensor(TensorValue) |
What we kept identical: the Jaxpr structure (invars, constvars, outvars, equations), primitive semantics, transform composition rules, VJP/JVP mathematical formulas, type promotion lattice, ThreeFry PRNG cipher, and control flow primitive semantics.
What we simplified: no XLA lowering (we interpret directly), no device placement (CPU only), no SPMD partitioning, no custom call mechanism, no platform-specific kernels.
What we added: Trace Transform Ledger (auditable composition proofs), RaptorQ durability (erasure-coded artifacts), decision/evidence ledger (loss-matrix runtime decisions), strict/hardened mode split, e-graph equality saturation optimizer.
make_jaxpr converts a Rust closure into a Jaxpr by running it with tracer values that record operations instead of computing them:
let jaxpr = make_jaxpr(|inputs| {
vec![&inputs[0] * &inputs[0] + &inputs[0]] // x^2 + x
}, vec![ShapedArray::scalar(DType::F64)])?;
// produces: Jaxpr: v2 = mul(v1, v1); v3 = add(v2, v1); output = v3
Each TracerRef carries an abstract value (dtype + shape) and a reference to the shared trace context. Operator overloading on +, -, * routes through binary_op with the matching primitive (for example, binary_op(Primitive::Add, ...) for +), which records an Equation in the active trace frame rather than evaluating it.
Nested scopes are handled via a frame stack. Control flow primitives (cond, scan) push a sub-trace frame, trace the body closure in that scope, pop the frame to get a sub-Jaxpr, and record the control flow equation in the parent frame. This enables tracing through arbitrary nesting of jit(grad(vmap(f))) where each transform traces its body and wraps the result.
Shape inference runs during tracing to validate tensor shapes at trace time rather than at eval time. The infer_primitive_output_avals function handles shape inference for all 110 primitives, including complex cases like broadcasting, gather indices, and convolution output shapes.
When some inputs to a Jaxpr are known at compile time, partial evaluation splits the program into a known part (pre-computable) and an unknown part (runtime-dependent):
Original: y = (2 * 3) + x [2 and 3 are known, x is unknown]
Known: const_6 = 2 * 3 [pre-computed to 6]
Unknown: y = const_6 + x [residual program, simpler]
The algorithm classifies each equation as "known" (all inputs derivable from known values) or "unknown" (has abstract dependencies), then identifies residual values, meaning intermediate results produced by known equations but consumed by unknown ones. These become constants in the residual Jaxpr.
This is the mechanism that enables jit to specialize programs: static shapes, dtypes, and constant arguments are folded away, leaving only the dynamic computation.
vmap does not loop over the batch dimension. Each primitive has an O(1) batching rule that handles the batch dimension as metadata:
Elementwise operations (sin, exp, add, mul, ...): The batch dimension passes through unchanged. If one operand is batched and the other is not, the unbatched operand is broadcast.
vmap(sin)(x) where x.shape=[batch, n]
=> sin(x) # batch dim just passes through, no loop
Reductions (reduce_sum, reduce_max, ...): The reduction axis is shifted to account for the batch dimension. Reduction never operates along the batch axis itself.
vmap(reduce_sum, axis=1)(x) where x.shape=[batch, m, n]
=> reduce_sum(x, axis=2) # axis shifted past batch dim
=> result.shape=[batch, m]
Shape operations (reshape, transpose, ...): The batch dimension is excluded from the reshape/transpose operation. The permutation indices are adjusted to skip position 0 (the batch axis).
Multi-output operations (QR, SVD, ...): Each output independently tracks its batch dimension position. For QR of a batched matrix [batch, m, n], Q gets batch_dim=0 with shape [batch, m, k] and R gets batch_dim=0 with shape [batch, k, n].
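The axis adjustment used by the reduction rule is plain index arithmetic. The sketch below uses two hypothetical helpers, `batched_axis` and `batched_reduce_shape`, and assumes the per-example axis is given without the batch dimension; it is not the BatchTrace interpreter itself.

```rust
/// Map a per-example axis to the corresponding axis of the batched tensor,
/// given that the batch dimension was inserted at `batch_dim`.
fn batched_axis(axis: usize, batch_dim: usize) -> usize {
    if axis >= batch_dim {
        axis + 1 // shifted past the batch dimension
    } else {
        axis
    }
}

/// Shape of a reduction result on a batched tensor: the shifted axis is
/// removed and the batch dimension is kept.
fn batched_reduce_shape(batched_shape: &[usize], axis: usize, batch_dim: usize) -> Vec<usize> {
    let dropped = batched_axis(axis, batch_dim);
    batched_shape
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != dropped)
        .map(|(_, &dim)| dim)
        .collect()
}
```

With `batch_dim = 0`, a per-example `axis = 1` becomes axis 2, so a `[batch, m, n]` input reduces to `[batch, m]`, matching the example above.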
Users can override the built-in VJP/JVP rules for any primitive:
fj_ad::register_custom_vjp(Primitive::MyOp, |inputs, grad, params| {
// Custom gradient logic
Ok(vec![custom_gradient])
});
Custom rules are stored in a global thread-safe registry (RwLock<BTreeMap<Primitive, Arc<dyn Fn>>>). During AD, the engine checks for custom rules before falling back to built-in rules. This enables:
- Overriding gradients for numerically unstable primitives with stable alternatives
- Implementing "stop gradient" by registering a zero-returning rule
- Adding gradient rules for new user-defined primitives
- Testing AD correctness by comparing custom vs built-in rules
The clear_custom_derivative_rules() function resets all registrations, used in test isolation.
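The lookup-then-fallback behavior can be sketched with the standard library alone. Everything below is a simplified assumption: primitives are keyed by string name instead of the `Primitive` enum, values are scalars, and `builtin_vjp`, `vjp`, and `clear_custom_rules` are hypothetical stand-ins for the fj-ad functions.

```rust
use std::collections::BTreeMap;
use std::sync::{Arc, RwLock};

/// A VJP rule maps (inputs, output cotangent) to input cotangents.
type VjpRule = Arc<dyn Fn(&[f64], f64) -> Vec<f64> + Send + Sync>;

/// Global thread-safe registry of user overrides.
static REGISTRY: RwLock<BTreeMap<&'static str, VjpRule>> = RwLock::new(BTreeMap::new());

fn register_custom_vjp(
    prim: &'static str,
    rule: impl Fn(&[f64], f64) -> Vec<f64> + Send + Sync + 'static,
) {
    REGISTRY.write().unwrap().insert(prim, Arc::new(rule));
}

fn clear_custom_rules() {
    REGISTRY.write().unwrap().clear();
}

/// Built-in rules; only `mul` is modeled here.
fn builtin_vjp(prim: &str, inputs: &[f64], g: f64) -> Vec<f64> {
    match prim {
        "mul" => vec![g * inputs[1], g * inputs[0]],
        _ => vec![0.0; inputs.len()],
    }
}

/// Check the registry first, then fall back to the built-in rule.
fn vjp(prim: &str, inputs: &[f64], g: f64) -> Vec<f64> {
    // Clone the Arc so the read lock is released before the rule runs.
    let custom = REGISTRY.read().unwrap().get(prim).cloned();
    match custom {
        Some(rule) => rule(inputs, g),
        None => builtin_vjp(prim, inputs, g),
    }
}
```

Registering a zero-returning rule for `"mul"` turns it into a stop-gradient for that primitive until `clear_custom_rules()` is called.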
Higher-order derivatives are built from first-order AD: the Jacobian composes JVPs directly, while the Hessian takes central finite differences of grad outputs.
Jacobian (forward-mode, JVP-based):
For each basis vector e_j in input space:
tangent_j = jvp(f, x, e_j) # One forward pass per input dimension
J[:, j] = tangent_j # Column j of Jacobian
Cost: input_dim forward passes. Returns a [output_dim, input_dim] matrix.
Hessian (mixed-mode, grad + finite differences):
For each basis vector e_k in input space:
g_plus = grad(f)(x + eps*e_k) # Gradient at perturbed point
g_minus = grad(f)(x - eps*e_k)
H[:, k] = (g_plus - g_minus) / (2*eps) # Central difference
Cost: 2 × input_dim gradient evaluations. Returns a symmetric [input_dim, input_dim] matrix. Uses ε = 1e-5 for numerical stability.
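The Hessian recipe above is short enough to write out. In this sketch the gradient is supplied as an ordinary closure standing in for grad(f); `hessian_fd` and its signature are assumptions for illustration, not the fj-api `hessian` entry point.

```rust
/// Approximate the Hessian by central differences of a gradient function.
/// Column k is (grad(x + eps*e_k) - grad(x - eps*e_k)) / (2*eps).
fn hessian_fd(grad: impl Fn(&[f64]) -> Vec<f64>, x: &[f64], eps: f64) -> Vec<Vec<f64>> {
    let n = x.len();
    let mut h = vec![vec![0.0; n]; n];
    for k in 0..n {
        // Perturb only coordinate k in both directions.
        let mut plus = x.to_vec();
        let mut minus = x.to_vec();
        plus[k] += eps;
        minus[k] -= eps;
        let g_plus = grad(&plus);
        let g_minus = grad(&minus);
        for i in 0..n {
            h[i][k] = (g_plus[i] - g_minus[i]) / (2.0 * eps);
        }
    }
    h
}
```

For f(x, y) = x²y the gradient is [2xy, x²], and at (3, 2) the sketch recovers [[4, 6], [6, 0]] up to rounding. Since each column is computed independently, symmetry holds only approximately in floating point.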
The FFT primitives use a direct O(n^2) evaluation of the discrete Fourier transform (DFT) rather than the O(n log n) Cooley-Tukey algorithm. This is a deliberate choice: correctness over speed for a reference implementation.
X[k] = sum_{j=0}^{n-1} x[j] * e^{-2*pi*i*j*k/n} (DFT)
x[j] = (1/n) sum_{k=0}^{n-1} X[k] * e^{+2*pi*i*j*k/n} (IDFT)
RFFT exploits Hermitian symmetry of real-valued input signals: only the first n/2 + 1 frequency bins are returned (the rest are conjugate mirrors). IRFFT reconstructs the full spectrum from the half-spectrum before applying IDFT.
All FFT operations support batching. The transform applies along the last axis, with leading dimensions treated as independent batch elements.
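The two summation formulas above map onto a double loop. The sketch below represents complex samples as `(re, im)` tuples and uses hypothetical `dft` and `rfft` helpers over a single 1-D signal; batching along leading axes is not modeled.

```rust
use std::f64::consts::PI;

/// Direct O(n^2) DFT. With `inverse = true` the exponent sign flips and
/// the result is scaled by 1/n (the IDFT).
fn dft(input: &[(f64, f64)], inverse: bool) -> Vec<(f64, f64)> {
    let n = input.len();
    let sign = if inverse { 1.0 } else { -1.0 };
    let scale = if inverse { 1.0 / n as f64 } else { 1.0 };
    (0..n)
        .map(|k| {
            let (mut re, mut im) = (0.0, 0.0);
            for (j, &(xr, xi)) in input.iter().enumerate() {
                // e^{i*angle} = cos(angle) + i*sin(angle)
                let angle = sign * 2.0 * PI * (j * k) as f64 / n as f64;
                let (s, c) = angle.sin_cos();
                re += xr * c - xi * s;
                im += xr * s + xi * c;
            }
            (re * scale, im * scale)
        })
        .collect()
}

/// Real-input transform: keep only the first n/2 + 1 frequency bins.
fn rfft(real: &[f64]) -> Vec<(f64, f64)> {
    let complex: Vec<(f64, f64)> = real.iter().map(|&r| (r, 0.0)).collect();
    dft(&complex, false)
        .into_iter()
        .take(real.len() / 2 + 1)
        .collect()
}
```

Applying `dft` followed by the inverse recovers the input up to rounding, and `rfft` of a length-4 real signal returns 3 bins.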
The conformance harness tracks parity at multiple granularities:
| Level | Tracks | Used For |
|---|---|---|
| Per-case | matched/mismatched, expected vs actual JSON, comparator type | Debugging individual failures |
| Per-family | Jit/Grad/Vmap/Lax/Random/ControlFlow/MixedDtype breakdowns | Identifying weak areas |
| Per-dtype | Tolerance tiers (F64: atol=1e-12; F32: atol=1e-5; Int: exact) | Precision-appropriate comparison |
| Suite-level | Overall pass rate, gate status (pass if rate=1.0) | CI gate decision |
| Drift classification | Pass/Regression/Improvement/Flake/Timeout | Trend tracking across releases |
The tolerance system uses dtype-aware tiers: double-precision results get tight 1e-12 tolerance, single-precision gets 1e-5, and integer/boolean results require exact match. This prevents false failures from legitimate floating-point variation while catching real regressions.
Binary operations in FrankenJAX support full NumPy-style broadcasting with five dispatch paths:
| LHS | RHS | Behavior |
|---|---|---|
| Scalar | Scalar | Direct operation |
| Scalar | Tensor | Broadcast scalar across tensor shape |
| Tensor | Scalar | Broadcast scalar across tensor shape |
| Tensor | Tensor (same shape) | Elementwise operation |
| Tensor | Tensor (different shape) | Multi-dimensional broadcasting |
Multi-dimensional broadcasting follows NumPy rules: shapes are right-aligned, dimensions of size 1 are stretched, and incompatible dimensions cause errors. For example, a [3, 1] tensor can broadcast with a [1, 4] tensor to produce a [3, 4] result.
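The shape rule in the last paragraph can be written as a small function. This is a sketch of the NumPy broadcasting rule for two shapes, with a hypothetical `broadcast_shapes` helper; it computes only the result shape and does not touch element data.

```rust
/// Compute the broadcast shape of `a` and `b`, or None if they are
/// incompatible. Shapes are right-aligned; size-1 dimensions stretch.
fn broadcast_shapes(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let mut out = vec![0; rank];
    for i in 0..rank {
        // Walk from the trailing dimension; missing dimensions act like 1.
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        out[rank - 1 - i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
    }
    Some(out)
}
```

`broadcast_shapes(&[3, 1], &[1, 4])` gives `Some([3, 4])`, while `[2, 3]` against `[4]` is rejected because the trailing dimensions 3 and 4 differ and neither is 1.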
Complex number arithmetic is handled separately with proper (a+bi)(c+di) = (ac-bd) + (ad+bc)i multiplication and (a+bi)/(c+di) conjugate-denominator division.
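Written out, the two formulas look like this. The sketch uses `(f64, f64)` pairs and the hypothetical names `complex_mul` and `complex_div`; it is the textbook form and does not guard against overflow for very large magnitudes, which a production implementation may handle differently.

```rust
/// (a+bi)(c+di) = (ac-bd) + (ad+bc)i
fn complex_mul((a, b): (f64, f64), (c, d): (f64, f64)) -> (f64, f64) {
    (a * c - b * d, a * d + b * c)
}

/// (a+bi)/(c+di): multiply numerator and denominator by the conjugate (c-di),
/// giving ((ac+bd) + (bc-ad)i) / (c^2 + d^2).
fn complex_div((a, b): (f64, f64), (c, d): (f64, f64)) -> (f64, f64) {
    let denom = c * c + d * d;
    ((a * c + b * d) / denom, (b * c - a * d) / denom)
}
```

For example, (1+2i)(3+4i) = -5+10i, and dividing -5+10i by 3+4i recovers 1+2i.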
FrankenJAX defends against several threat categories relevant to ML infrastructure:
Cache confusion attacks: Malicious or corrupted cache entries could cause a program to silently produce wrong results. The SHA-256 fingerprinting system binds cache keys to the exact computation (Jaxpr structure + transforms + mode + backend), and Strict mode rejects any unknown features.
Transform-order vulnerabilities: grad(vmap(f)) and vmap(grad(f)) are not interchangeable and can produce different results. The Trace Transform Ledger records the exact transform composition order, preventing silent reordering. The dispatcher verifies composition compatibility before execution.
Malformed graph attacks: Adversarial or corrupted Jaxpr graphs could trigger panics or undefined behavior. All graph traversal validates arities, shapes, and dtypes at each equation. The fuzz testing infrastructure (cargo fuzz) continuously tests the IR deserializer and evaluator against malformed inputs.
Silent data corruption: Conformance fixtures and benchmark baselines could be corrupted on disk or in transit. RaptorQ sidecars provide erasure coding that detects and recovers from partial data loss, with decode proofs that verify recovery correctness.
FrankenJAX validates against real JAX output, not just hand-computed expected values:
# Set up JAX environment
uv venv --python 3.12 .venv
uv pip install --python .venv/bin/python jax jaxlib numpy
# Capture oracle fixtures from JAX (strict mode = real JAX, no fallback)
.venv/bin/python crates/fj-conformance/scripts/capture_legacy_fixtures.py \
--legacy-root legacy_jax_code/jax \
--output crates/fj-conformance/fixtures/transforms/legacy_transform_cases.v1.json \
--rng-output crates/fj-conformance/fixtures/rng/rng_determinism.v1.json \
--strict

| Fixture Family | Cases | Source |
|---|---|---|
| Transform (jit/grad/vmap/lax/control_flow/mixed_dtype) | 611 | JAX 0.9.2 |
| RNG determinism (key/split/fold_in/uniform/normal) | 25 | JAX 0.9.2 |
| Linear algebra + FFT oracle | 33 | JAX 0.9.2 (x64) |
| Transform composition (jit+grad, grad+grad, vmap+grad, jacobian, hessian) | 15 | JAX 0.9.2 (x64) |
| Dtype promotion (9x9 dtype matrix, add + mul) | 162 | JAX 0.9.2 (x64) |
| Total | 846 | |
# Clone the repository
git clone https://github.com/Dicklesworthstone/frankenjax.git
cd frankenjax
# Rust nightly is required (see rust-toolchain.toml)
rustup install nightly
rustup override set nightly
# Build the entire workspace
cargo build --workspace
# Run all tests
cargo test --workspace
# Build with release optimizations (LTO + single codegen unit)
cargo build --workspace --release

Optional: Set up JAX oracle environment for conformance capture:
uv venv --python 3.12 .venv
uv pip install --python .venv/bin/python jax jaxlib numpy
git clone --depth 1 https://github.com/jax-ml/jax.git legacy_jax_code/jax

Each E2E scenario produces a forensic JSON log at artifacts/e2e/<scenario>.e2e.json containing:
{
  "scenario": "e2e_p2c001_full_dispatch_pipeline",
  "replay_command": "cargo test -p fj-conformance --test e2e ...",
  "input_capture": { "args": [...], "transforms": [...] },
  "intermediate_states": {
    "traced_jaxpr": { "equations": [...] },
    "optimized_jaxpr": { "equations": [...] },
    "ledger_entries": [...]
  },
  "output_capture": { "values": [...], "dtype": "F64" },
  "timing": { "trace_ms": 0.5, "dispatch_ms": 1.2, "eval_ms": 0.8 }
}

This enables full replay: given the input capture, you can re-run the exact same computation and verify the output matches. The intermediate states show the Jaxpr before and after optimization, and the ledger entries show every decision the dispatcher made. Timing breakdowns separate trace, dispatch, and evaluation phases for performance analysis.
# Format check
cargo fmt --check
# Compiler + lint
cargo check --all-targets
cargo clippy --all-targets -- -D warnings
# Full test suite
cargo test --workspace
# Conformance tests with output
cargo test -p fj-conformance -- --nocapture
# Benchmarks
cargo bench

# Run all E2E scenarios
./scripts/run_e2e.sh
# Run one scenario
./scripts/run_e2e.sh --scenario e2e_p2c001_full_dispatch_pipeline
# Each scenario emits forensic logs at artifacts/e2e/<scenario>.e2e.json

# Full gates (coverage + flake + runtime + crash triage)
./scripts/enforce_quality_gates.sh
# Local iteration (fast)
./scripts/enforce_quality_gates.sh --skip-coverage --flake-runs 3
# Flake detector standalone
./scripts/detect_flakes.sh --runs 10

All long-lived artifacts get RaptorQ erasure-coded sidecars:
# Generate sidecar + scrub + decode proof (all-in-one)
cargo run -p fj-conformance --bin fj_durability -- \
pipeline --artifact <path> --sidecar <sidecar_path> \
--report <scrub_report_path> --proof <decode_proof_path>
# Batch process a directory
cargo run -p fj-conformance --bin fj_durability -- \
batch --dir artifacts/e2e --output artifacts/durability --json

cd crates/fj-conformance/fuzz
cargo fuzz build
cargo fuzz run ir_deserializer corpus/seed/ir_deserializer

The 110 operations in the Primitive enum span the full breadth of JAX's LAX API:
| Category | Primitives | Count |
|---|---|---|
| Arithmetic | Add, Sub, Mul, Div, Neg, Abs, Rem, Pow, Max, Min, Exp, Log, Sqrt, Rsqrt, Floor, Ceil, Round, Expm1, Log1p, Sign, Square, Reciprocal, Logistic | 23 |
| Trigonometric | Sin, Cos, Tan, Asin, Acos, Atan, Atan2 | 7 |
| Hyperbolic | Sinh, Cosh, Tanh | 3 |
| Special math | Erf, Erfc, Cbrt, Lgamma, Digamma, ErfInv, IsFinite, IntegerPow, Nextafter | 9 |
| Complex | Complex, Conj, Real, Imag | 4 |
| Comparison | Eq, Ne, Lt, Le, Gt, Ge | 6 |
| Reduction | ReduceSum, ReduceMax, ReduceMin, ReduceProd, ReduceAnd, ReduceOr, ReduceXor, ReduceWindow | 8 |
| Shape manipulation | Reshape, Transpose, BroadcastInDim, Slice, DynamicSlice, DynamicUpdateSlice, Gather, Scatter, Concatenate, Pad, Rev, Squeeze, Split, ExpandDims | 14 |
| Linear algebra | Cholesky, QR, SVD, TriangularSolve, Eigh | 5 |
| FFT | Fft, Ifft, Rfft, Irfft | 4 |
| Bitwise | BitwiseAnd, BitwiseOr, BitwiseXor, BitwiseNot, ShiftLeft, ShiftRightArithmetic, ShiftRightLogical, PopulationCount, CountLeadingZeros | 9 |
| Control flow | Cond, Scan, While, Switch | 4 |
| Other | Dot, Select, Clamp, Iota, BroadcastedIota, Copy, BitcastConvertType, ReducePrecision, OneHot, Cumsum, Cumprod, Sort, Argsort, Conv | 14 |
Every primitive has:
- Full evaluation in fj-lax (scalar and tensor paths)
- VJP rule (reverse-mode gradient) in fj-ad
- JVP rule (forward-mode tangent) in fj-ad
- NumPy-style broadcasting for binary operations
FrankenJAX includes a library of 145+ pre-defined test programs (ProgramSpec enum) that serve as standardized inputs for conformance testing, benchmarking, and AD verification:
| Category | Examples | Count |
|---|---|---|
| Basic arithmetic | Add2, Square, SquarePlusLinear, AddOne | ~10 |
| Unary LAX ops | LaxNeg, LaxAbs, LaxExp, LaxLog, LaxSqrt, LaxSin, LaxCos, ... | ~30 |
| Binary LAX ops | LaxAdd, LaxMul, LaxDiv, LaxPow, LaxMax, LaxMin, ... | ~15 |
| Special math | LaxCbrt, LaxLgamma, LaxDigamma, LaxErfInv, LaxIsFinite | ~10 |
| Shape manipulation | LaxReshape6To2x3, LaxTranspose2x3, LaxSlice1To4, LaxRev, ... | ~15 |
| Linear algebra | LaxCholesky, LaxQr, LaxSvd, LaxEigh, LaxTriangularSolve | 5 |
| FFT | LaxFft, LaxIfft, LaxRfft, LaxIrfft | 4 |
| Control flow | CondSelect, ScanAdd, LaxWhileAddLt, LaxSwitch3 | ~5 |
| Bitwise | LaxBitwiseAnd, LaxShiftLeft, LaxPopulationCount, ... | ~10 |
| Cumulative/Sort | LaxCumsum, LaxCumprod, LaxSort, LaxArgsort | 4 |
| Reductions | LaxReduceSum, LaxReduceMax, LaxReduceWindow | ~5 |
| Other | LaxIota5, LaxOneHot4, LaxConv1dValid, LaxCopy | ~30 |
Each program is constructed via build_program(spec) -> Jaxpr, returning a ready-to-evaluate Jaxpr. This enables systematic testing: run every program through jit, grad, vmap, and composed transforms, comparing against oracle values.
FrankenJAX is built on carefully chosen Rust crates:
| Dependency | Purpose | Why This Crate |
|---|---|---|
| egg | E-graph equality saturation | Industry-standard for term rewriting; extensible language and cost model |
| smallvec | Inline small-vector for IR nodes | Avoids heap allocation for typical 1-4 element equation inputs/outputs |
| rustc-hash | Fast non-crypto hashing | FxHashMap for interpreter environments; faster than SipHash for small keys |
| serde + serde_json | Serialization | IR nodes, fixtures, ledger entries all need deterministic serialization |
| sha2 | Cryptographic hashing | Cache keys and durability sidecars require collision-resistant hashing |
| proptest | Property-based testing | Generates random inputs to verify algebraic invariants (commutativity, associativity, etc.) |
| criterion | Microbenchmarking | Statistically rigorous benchmarks for dispatch and interpreter hot paths |
| rayon | Data parallelism | Dependency-wave executor parallelizes independent equations |
| half | Half-precision floats | BF16 and F16 scalar representation |
| base64 | Encoding | RaptorQ sidecar symbols encoded as Base64 for JSON storage |
| tempfile | Temporary directories | Test isolation for conformance fixture and durability pipeline tests |
Feature flags:
[features]
default = []
asupersync-integration = ["dep:asupersync"] # Structured async runtime
frankentui-integration = ["dep:ftui"] # Terminal UI rendering

Build profile:
[profile.release]
opt-level = 3 # Maximum optimization
lto = true # Link-time optimization
codegen-units = 1 # Single codegen for best optimization
strip = true # Remove debug symbols

The workspace uses Rust edition 2024 (nightly) and enforces #![forbid(unsafe_code)] in every crate except fj-ffi (which provides the C FFI boundary).
Several primitives use polynomial or series approximations rather than calling libm:
| Function | Algorithm | Accuracy |
|---|---|---|
| erf(x) | Horner-form rational approximation (Abramowitz & Stegun 7.1.26) | ~1e-7 relative error |
| lgamma(x) | Lanczos approximation with 7 coefficients, reflection formula for x < 0.5 | ~1e-10 |
| digamma(x) | Asymptotic series with recurrence reduction to x > 6 | ~1e-8 |
| erf_inv(x) | Rational minimax approximation with tail correction | ~1e-6 |
| trigamma(x) | Asymptotic expansion with recurrence | ~1e-8 |
These approximations avoid platform-dependent libm behavior, ensuring identical results across operating systems. The tolerances are verified against reference values in the test suite.
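To make the erf row concrete, here is a minimal sketch of the Abramowitz & Stegun 7.1.26 formula evaluated in Horner form. The coefficients are the published ones for that formula, whose stated bound is an absolute error of about 1.5e-7; the function name `erf_approx` is hypothetical, and the real fj-lax code may differ in structure.

```rust
/// erf(x) via Abramowitz & Stegun 7.1.26:
/// erf(x) ~= 1 - (a1 t + a2 t^2 + a3 t^3 + a4 t^4 + a5 t^5) * exp(-x^2),
/// with t = 1 / (1 + p x) for x >= 0, extended to x < 0 by odd symmetry.
fn erf_approx(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    const A: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    // Horner evaluation: ((((a5 t + a4) t + a3) t + a2) t + a1) t
    let poly = A.iter().rev().fold(0.0_f64, |acc, &a| (acc + a) * t);
    sign * (1.0 - poly * (-ax * ax).exp())
}
```

Because the formula uses only multiplication, addition, and `exp`, its output depends on the platform only through `exp`, which is the kind of dependence the approach above aims to minimize.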
Each crate defines its own error type, and errors propagate upward through the stack:
EvalError (fj-lax) --+
InterpreterError --+
AdError (fj-ad) --+--> ApiError (fj-api)
DispatchError --+
DurabilityError --+
Error types are enums with structured variants (not string messages), so callers can match on specific failure modes:
enum EvalError {
    ArityMismatch { primitive: Primitive, expected: usize, actual: usize },
    ShapeMismatch { primitive: Primitive, left: Shape, right: Shape },
    TypeMismatch { primitive: Primitive, detail: &'static str },
    Unsupported { primitive: Primitive, detail: String },
}

No panics in library code. All error paths return Result. The only unwrap() calls are in test code.
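A caller can branch on these variants rather than parsing message strings. The sketch below repeats the enum so that it compiles on its own; the `Primitive` enum and `Shape` alias are reduced stand-ins, and `check_arity` and `describe` are hypothetical helpers written only for this example.

```rust
// Stand-ins for the real fj-lax types, reduced for illustration.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq)]
enum Primitive {
    Add,
    Dot,
}

type Shape = Vec<usize>;

#[allow(dead_code)]
#[derive(Debug)]
enum EvalError {
    ArityMismatch { primitive: Primitive, expected: usize, actual: usize },
    ShapeMismatch { primitive: Primitive, left: Shape, right: Shape },
    TypeMismatch { primitive: Primitive, detail: &'static str },
    Unsupported { primitive: Primitive, detail: String },
}

/// Hypothetical validation step: return a structured error instead of panicking.
fn check_arity(primitive: Primitive, expected: usize, inputs: &[f64]) -> Result<(), EvalError> {
    if inputs.len() == expected {
        Ok(())
    } else {
        Err(EvalError::ArityMismatch { primitive, expected, actual: inputs.len() })
    }
}

/// Callers match on the variant and read its fields directly.
fn describe(err: &EvalError) -> String {
    match err {
        EvalError::ArityMismatch { primitive, expected, actual } => {
            format!("{:?} expected {} inputs, got {}", primitive, expected, actual)
        }
        EvalError::ShapeMismatch { primitive, left, right } => {
            format!("{:?} shape mismatch: {:?} vs {:?}", primitive, left, right)
        }
        EvalError::TypeMismatch { primitive, detail } => {
            format!("{:?} type mismatch: {}", primitive, detail)
        }
        EvalError::Unsupported { primitive, detail } => {
            format!("{:?} unsupported: {}", primitive, detail)
        }
    }
}
```

The structured fields let a caller decide, for example, to retry with reshaped inputs on a shape mismatch while treating an unsupported primitive as fatal.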
FrankenJAX uses several guards against numerical issues:
- Degenerate eigenvalue handling: SVD and Eigh VJP rules check |σ_i - σ_j| > 1e-20 before dividing by eigenvalue gaps. Near-degenerate eigenvalues produce zero gradient contributions instead of infinities.
- Cholesky diagonal guard: The Cholesky decomposition checks for non-positive diagonal elements, which indicate a non-SPD input matrix.
- Division-by-zero protection: The div primitive produces NaN for 0/0 and Inf for x/0, consistent with IEEE 754.
- Triangular solve stability: Forward/back substitution checks diagonal elements against machine epsilon before dividing.
- FFT scaling: IFFT uses exact 1/n scaling, and RFFT interior bins get the correct factor-of-2 for real-signal Hermitian symmetry.
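Two of these guards are small enough to sketch directly. The helper names `guarded_inv_gap` and `ieee_div` are hypothetical, and the exact expression the SVD/Eigh VJP rules divide by is not shown in this document, so the reciprocal-gap form here is an assumption.

```rust
/// Reciprocal of an eigenvalue gap with the 1e-20 degeneracy guard:
/// near-degenerate pairs contribute zero instead of an infinity.
fn guarded_inv_gap(sigma_i: f64, sigma_j: f64) -> f64 {
    let gap = sigma_i - sigma_j;
    if gap.abs() > 1e-20 {
        1.0 / gap
    } else {
        0.0
    }
}

/// Plain IEEE 754 division: 0/0 is NaN and nonzero/0 is a signed infinity,
/// so no extra branching is needed to get the behavior described above.
fn ieee_div(x: f64, y: f64) -> f64 {
    x / y
}
```

With these definitions, `guarded_inv_gap(1.0, 1.0)` is 0.0 rather than infinity, while `ieee_div(0.0, 0.0)` is NaN and `ieee_div(1.0, 0.0)` is positive infinity.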
Jaxpr equations can carry effect tokens that enforce sequential execution ordering for side-effectful operations. The effect system ensures that:
- Effectful equations execute in program order (never reordered by the parallel executor)
- The dependency-wave scheduler treats effectful equations as barriers
- Transform composition preserves effect ordering (grad cannot reorder effects)
- Sub-Jaxprs in control flow inherit their parent's effect context
Currently used for: RNG state threading (ensuring deterministic random sequences), and as extension points for future I/O or mutation effects.
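One way to picture the barrier rule is a wave-assignment pass over the equations in program order. The sketch below is an illustrative model only: the `Eqn` struct, its fields, and `assign_waves` are hypothetical and much simpler than the real Jaxpr equation type and scheduler.

```rust
use std::collections::HashMap;

/// Hypothetical, simplified equation: variable ids in, variable ids out.
struct Eqn {
    inputs: Vec<usize>,
    outputs: Vec<usize>,
    effectful: bool,
}

/// Assign each equation a wave index. Equations sharing a wave have no data
/// dependencies on each other; an effectful equation gets a wave to itself.
fn assign_waves(eqns: &[Eqn]) -> Vec<usize> {
    let mut produced_in: HashMap<usize, usize> = HashMap::new(); // var id -> wave
    let mut waves = Vec::with_capacity(eqns.len());
    let mut floor: usize = 0; // earliest wave allowed after the last barrier
    let mut next_free: usize = 0; // one past the highest wave used so far
    for eqn in eqns {
        // Data dependencies: run after every producer of an input.
        let mut wave = floor;
        for v in &eqn.inputs {
            if let Some(&w) = produced_in.get(v) {
                wave = wave.max(w + 1);
            }
        }
        if eqn.effectful {
            // Barrier: after everything before it, before everything after it.
            wave = wave.max(next_free);
            floor = wave + 1;
        }
        for &v in &eqn.outputs {
            produced_in.insert(v, wave);
        }
        next_free = next_free.max(wave + 1);
        waves.push(wave);
    }
    waves
}
```

In this model, two pure equations that read only program inputs share wave 0, an effectful equation that follows them is pushed into a wave of its own, and any later equation starts no earlier than the wave after that barrier, so effectful equations always run in program order.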
The Gather and Scatter primitives implement generalized indexing:
Gather: Extract slices from a tensor at computed index positions. Given an operand [m, n] and indices [k, index_depth], produces [k, slice_size] by looking up each index position. Supports multi-dimensional operands and configurable slice sizes.
Scatter: The inverse of gather. Updates a tensor at computed index positions. Given a base tensor, indices, and update values, produces a new tensor with updates applied at the specified positions. Supports configurable update modes (overwrite, add-to-existing).
Both primitives have VJP and JVP rules. Gather's VJP produces a scatter (adjoint of indexing is scattering gradients back), and Scatter's VJP produces a gather (adjoint of scattering is gathering the relevant gradients).
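The adjoint relationship is easiest to see in one dimension. The sketch below uses hypothetical helpers (`gather`, `scatter_add`, `gather_vjp`) over flat slices; the real primitives handle multi-dimensional operands and slice sizes, and would return an error rather than panic on an out-of-range index.

```rust
/// 1-D gather: out[k] = operand[indices[k]].
/// Panics on an out-of-range index; acceptable for a sketch only.
fn gather(operand: &[f64], indices: &[usize]) -> Vec<f64> {
    indices.iter().map(|&i| operand[i]).collect()
}

/// 1-D scatter in add mode onto a zero base: out[indices[k]] += updates[k].
fn scatter_add(len: usize, indices: &[usize], updates: &[f64]) -> Vec<f64> {
    let mut out = vec![0.0; len];
    for (&i, &u) in indices.iter().zip(updates) {
        out[i] += u;
    }
    out
}

/// VJP of gather with respect to the operand: scatter the output cotangent
/// back to the positions that were read. Repeated indices accumulate.
fn gather_vjp(operand_len: usize, indices: &[usize], cotangent: &[f64]) -> Vec<f64> {
    scatter_add(operand_len, indices, cotangent)
}
```

For operand `[10, 20, 30]` and indices `[2, 0, 2]`, gather returns `[30, 10, 30]`, and a cotangent of `[1, 2, 3]` maps back to `[2, 0, 4]`: index 2 was read twice, so its contributions add. The inner products agree on both sides (140 in each case), which is the defining property of an adjoint pair.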
- CPU-only backend. GPU/TPU backends are not yet implemented. The CPU backend uses rayon for wave-parallel execution.
- No F32 scalar literals. All float scalars are stored as F64 internally. F32 tensor operations work via TensorValue, but scalar-level F32 promotion differs from JAX.
- Bool arithmetic. FrankenJAX does not treat Bool as numeric in arithmetic operations (JAX does: True + True = 2).
- No XLA lowering. FrankenJAX evaluates through its own interpreter, not through XLA. This means we match JAX's mathematical semantics but not its compilation/optimization pipeline.
- Partial vmap + control flow composition. vmap(scan(...)) and similar compositions need further work.
Q: Why not just use JAX directly?
A: JAX requires Python and XLA, plus CUDA or TPU runtimes for accelerator use. FrankenJAX gives you the mathematical transform semantics in a standalone Rust library with no Python dependency.
Q: How do you verify correctness without running JAX?
A: We capture oracle fixtures from real JAX 0.9.2 (834 cases), then run our Rust implementation against those fixtures in CI. Every primitive, every transform, every dtype combination.
Q: Is the AD (automatic differentiation) complete?
A: Yes. All 110 primitives have both VJP (reverse-mode) and JVP (forward-mode) rules, including complex operations like Cholesky, QR, SVD, Eigh decompositions and FFT. Numerical verification tests confirm correctness via finite-difference comparison.
Q: What's the Trace Transform Ledger?
A: Every transform composition (jit(grad(f)), vmap(grad(f)), etc.) produces a verifiable proof artifact that records the input IR, applied transforms, and output IR. This makes transform composition auditable.
Q: How fast is it?
A: Performance work is at an early stage. The CPU backend uses a dependency-wave parallel executor, and the e-graph optimizer applies 87 algebraic simplification rules. Profiling infrastructure is in place, but systematic optimization hasn't started yet.
Q: What's the difference between Strict and Hardened mode?
A: Strict mode refuses to process anything it does not fully understand. Unknown features, incompatible cache entries, or ambiguous inputs cause hard failures. Hardened mode allows bounded defensive recovery: it can handle some malformed inputs and degrade gracefully, but logs every recovery action in the decision ledger for audit.
Q: How does the e-graph optimizer work?
A: It converts your Jaxpr into an equivalence graph where every algebraically equivalent form exists simultaneously (e.g., x+x and 2*x coexist as equivalent). The 87 rewrite rules fire until saturation (no new equivalences found), then a cost function extracts the cheapest program. This can discover simplifications that sequential rule application would miss.
Q: Can I use FrankenJAX from Python/C?
A: The fj-ffi crate provides a C FFI surface for calling FrankenJAX from any language with C interop. Python bindings are not yet implemented but would be straightforward via PyO3 or cffi on top of fj-ffi.
Q: How are the linalg AD rules verified?
A: Every linalg VJP and JVP rule is verified two ways: (1) numerical finite-difference comparison (perturb inputs, compare analytical vs numerical gradient), and (2) oracle comparison against JAX's output with x64 precision enabled. This caught two real bugs in the Cholesky decomposition AD during development: a missing diagonal-halving factor and a wrong triangular-solve direction.
Q: What's RaptorQ and why is it in a math library?
A: RaptorQ is a fountain code (erasure code) that can reconstruct data from any sufficient subset of encoded symbols. We use it to protect conformance fixtures and benchmark baselines against silent data corruption. In a distributed CI setup where cache nodes can lose data, this ensures your test evidence survives. Unusual for a math library, yes, but correctness evidence that cannot be trusted is worthless.
Q: How does the RNG match JAX?
A: FrankenJAX implements the exact same ThreeFry2x32 cipher with the same rotation constants, key schedule, and sampling algorithms (Box-Muller for normal, Gumbel-max for categorical). The determinism is verified against 25 JAX oracle fixtures. random_key(42) produces bit-identical output to JAX's jax.random.key(42).
| Document | Purpose |
|---|---|
| AGENTS.md | AI agent development guidelines |
| FEATURE_PARITY.md | Feature-by-feature JAX parity status (all green) |
| COMPREHENSIVE_SPEC_FOR_FRANKENJAX_V1.md | Full V1 specification |
| PLAN_TO_PORT_JAX_TO_RUST.md | Original porting strategy |
| EXISTING_JAX_STRUCTURE.md | JAX architecture analysis |
| PROPOSED_ARCHITECTURE.md | FrankenJAX design decisions |
Please don't take this the wrong way, but I do not accept outside contributions for any of my projects. I simply don't have the mental bandwidth to review anything, and it's my name on the thing, so I'm responsible for any problems it causes; thus, the risk-reward is highly asymmetric from my perspective. I'd also have to worry about other "stakeholders," which seems unwise for tools I mostly make for myself for free. Feel free to submit issues, and even PRs if you want to illustrate a proposed fix, but know I won't merge them directly. Instead, I'll have Claude or Codex review submissions via gh and independently decide whether and how to address them. Bug reports in particular are welcome. Sorry if this offends, but I want to avoid wasted time and hurt feelings. I understand this isn't in sync with the prevailing open-source ethos that seeks community contributions, but it's the only way I can move at this velocity and keep my sanity.