Vin Howe

You probably shouldn’t train a language model in your browser—here’s how

Blog post · November 21, 2025

For better or worse, I wrote this post myself, em-dashes1 and all.

I just released Sequence Toy, a playground for training language models entirely in your browser. I also built Piston, a proof-of-concept2 WebGPU deep learning library modeled after PyTorch, for the sole purpose of powering Sequence Toy. This blog post is a vehicle for me to talk about the work that went into the combined project by way of a narrative device: what will you want to know when you inevitably build yourtorch, your own little WebGPU PyTorch clone? I’ll start by pointing out how impractical this is and maybe dissuade you in the process.

Some important notes:

  • You’re welcome to skip the first section if you’d like to go straight to technical details.
  • This post does not function as a working tutorial, only a roadmap. If you decide to take it literally and refer to this post as you build something, you will need to be able to fill in some gaps with your own research. I also recommend reading it all the way through first, or you’ll end up redoing a few things. You’re also welcome to DM me!

Introduction: this exercise almost certainly fails to justify itself

Piston is not the first to implement machine learning or deep learning on the web. In 2013, ConvNetJS by Andrej Karpathy was probably the first to train small models in pure JavaScript—keep in mind that before AlexNet came out the year prior, neural networks hadn’t taken off, so this was in the primordial soup of the ongoing AI boom. Three years later, inspired by ConvNetJS, we got A Neural Network Playground, but it wasn’t until 2018 that TensorFlow.js managed to achieve some limited GPU acceleration by repurposing WebGL shaders as compute kernels. In 2023, Frank Krueger built webgpu-torch, the project most similar to Piston that I’m aware of, so I’m not the first to build something PyTorchy on top of WebGPU. But, to my knowledge, Piston is the first project that combines enough compute shaders and WebGPU performance considerations in one place to train something as involved as language models, albeit at a small scale, in a web browser.

Still: I’d have to imagine that for most rational agents, being a pioneer isn’t enough of a reason to do something like this. You should ideally be able to pay back the opportunity cost you spend working on yourtorch by unlocking some amount of practical value, or you end up creating an impressive but purely technical curiosity—like DOOM in TypeScript types—instead. The bad news, then, is that the browser is a punishingly impractical place to train a language model. Conversational language models capable enough to use every day cost on the order of 100 million USD, at least as far back as 2022, to train across 20K–100K GPUs—primarily the preserve of only a handful of AI labs—and that was before extensive RL post-training3 became an expectation. To put into perspective the computational difference between training and inference, consider a rule of thumb introduced in 2023 which put the FLOPs of a single inference at roughly the square root of those required for training—for a model that requires 10^12 FLOPs per inference, training would take on the order of 10^24 FLOPs, and that difference in scale is roughly what 31,688 years is to a second.

Some quick math to make this feel more concrete in the browser: the smallest of the GPT-2 family, released in 2019, was impressive at the time for spouting grammatical text matching the broad style of the input prompt—nonsense nonetheless—after being trained on 40GB, or ~21 billion tokens4, of internet text. We know precious little about the scale of the GPT-5s of our day, but the recent-as-of-publication open-source Qwen3 family of models, the largest of which has 235 billion total parameters, were trained on 36 trillion tokens of text. Simplistically ignoring differences in MoE sparsity and tokenizer vocabulary size, the difference in both parameters and data is roughly 3 orders of magnitude. With Piston, the biggest autoregressive transformer I have successfully fit on the GPU in my 16GB M1 Pro is ~50M parameters, and I couldn’t convince WebGPU to use more GPUs even if I had them5. At roughly a step per second on my laptop, with a batch size of 1 and a context length of 32, it just so happens that I could manage about a billion tokens a year. So it would take 21 years to make my poor distilgpt2-wannabe choke down the 21 billion tokens of internet they used to make 2019’s best generative model, or slightly more than 31,709 years to show it Qwen3’s curriculum. This should help explain why most in-browser AI investment goes to inference, like transformers.js and WebLLM, and why practically all browser-training projects are demos and toys.

Finally, as if all that weren’t enough, this is reasonably involved technical work! If the web platform looked more like the deep learning ecosystem, and especially if WebGPU looked more like CUDA (I’ll discuss this later), the activation energy associated with porting over PyTorch might be low enough to ignore the impracticality of the thing. But given the number of web-specific considerations you’ll need to handle to implement yourtorch (even with this blog post under your belt), it would be sane of you to consider what else you could do with a similar investment of time. In my case, for Sequence Toy and Piston, I’m forced to admit that the value was mostly educational. The same way someone might implement a compiler, backprop, tokenizers, or GPT-2 from scratch to understand what’s going on under the hood, I implemented generic encoder, decoder, and encoder-decoder transformer architectures and everything needed to train them on a single GPU6. The added value of doing it on the web with WebGPU, despite all of the ways its design choices are a mismatch for deep learning, is that the work I do is not quite as redundant as if, say, I implemented a C++ (or Rust) deep learning framework with CUDA support. Plus, I can build a fun demo, write a blog post about it, and claim to be first to something. There are plenty of ways to outdo me when you build yourtorch: maybe figure out how to distribute training across browsers in a peer-to-peer way using DiLoCo—that might be the only way you can convince WebGPU to use multiple GPUs! Or engineer an intermediate representation that you can then do optimizations over to achieve things like operator fusion, which we’ll talk about later.

If four paragraphs weren’t enough to deter you, here’s a high-level overview of how to build yourtorch, along with some hard-earned considerations.

First, you’ll want a tensor

Given that Piston borrows heavily from PyTorch’s choices of abstraction, you might find ezyang’s blog on PyTorch internals to be a useful supplementary resource, especially if your goal is specifically to understand PyTorch better.

The tensor is the core primitive of the modern deep learning framework, in the same way that the ndarray is NumPy’s central load-bearing data structure; if you’re already the sort of person who thinks about building yourtorch, this has probably occurred to you. But it’s worth being very clear about what a tensor is, because before anything else, you will build a tensor. So, to recap: a tensor object describes an n-dimensional array. This is represented as a handle to a data buffer on a device—usually either the CPU or a GPU—along with the metadata you’ll need to interpret that data as an array. Here’s the minimum set of such metadata you’ll need for yourtorch.Tensor:

  • sizes: an array specifying the tensor’s dimensions.
  • strides: an array the same length as sizes, specifying how the tensor is laid out in memory—see ezyang’s section on strides for an introduction.
  • dtype: the data type associated with that tensor. In Piston7 I have f16 and f32, which represent 16- and 32-bit floating point types, and i32, a 32-bit integer type. PyTorch implements many more.
  • device: whether the tensor is in WebGPU memory or in a regular memory buffer, on the CPU.

So far, except for the bit about devices, this is very close to how you might expect to see an ndarray defined, if you’re familiar with NumPy. What most cleanly distinguishes these two array types is that autodifferentiation (and, as we’ll discuss later, graph execution) requires us to keep track of the graph of operations performed on tensors. For example, if we want to be able to compute the gradient of tensor1.add(tensor2), its tensor result should record which operation created it—addition—and its inputs—tensor1 and tensor2. This allows PyTorch, Piston, and eventually yourtorch to create execution graphs where each tensor object contains both its result node and, by way of the operation metadata it tracks, graph edges from its inputs. Leaf nodes are the base case here—for instance, in our addition example, tensor1 and tensor2 aren’t the result of another operation, so they both record a constant-operation with no inputs, where their data was probably supplied by some factory function (e.g. yourtorch.{zeros,full,randn,...}).
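
To make this concrete, here is a minimal sketch of what such a tensor type could look like in TypeScript; everything here is illustrative rather than Piston’s actual definitions, in particular the storage handle and the way operations are recorded:

type DType = "f16" | "f32" | "i32";
type Device = "gpu" | "cpu";

interface Op {
  name: string;     // "add", "matmul", "constant", ...
  inputs: Tensor[]; // graph edges back to this tensor's dependencies
}

class Tensor {
  constructor(
    public sizes: number[],   // e.g. [2, 3]
    public strides: number[], // e.g. [3, 1] for a contiguous [2, 3] tensor
    public dtype: DType,
    public device: Device,
    // Handle to the underlying data: a WebGPU buffer or a CPU buffer,
    // or null until the value has actually been computed.
    public storage: GPUBuffer | ArrayBuffer | null,
    public op: Op = { name: "constant", inputs: [] }, // leaf tensors record no inputs
  ) {}

  add(other: Tensor): Tensor {
    // The result remembers which operation created it and from what, so we can
    // later walk the graph for execution and for backpropagation.
    return new Tensor(this.sizes, this.strides, this.dtype, this.device, null, {
      name: "add",
      inputs: [this, other],
    });
  }
}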

Your first order of business after implementing a basic tensor API is to define WebGPU compute shaders in WGSL for every operation used in the forward pass of the neural network architecture you’re targeting. My original goal was a faithful implementation of minGPT, because I could take a common set of inputs—anything, really—and compare its forward pass with mine, operation by operation, to make sure my shaders checked out numerically. I ended up needing 25 unique shaders for my decoder-only forward pass, out of 81 overall8. You might find that now is a convenient time to implement a nn.Module analog—you could even refer to module.py.

But hang on, don’t start writing shaders just yet. Allow me to point out a few logical groupings:

  • Many operations, like sin, log, square root, and ReLU, are simply element-wise functions of the input. These are unary operations.
  • Addition, subtraction, multiplication, and division are all binary operations: two inputs and an output, all with the same data type.
  • Comparison operations take same-data-type inputs and nominally produce boolean-valued output, but we settle for 32-bit integers. This includes equality, inequality, greater than, and less than—along with their respective non-strict variants.
  • If you get this far, you’ll discover that sum, minimum, maximum, argmin, and argmax are all structurally quite similar. These all perform parallel reduction, so they’re called reduce operations.

You’ll have a better time if you write code that generates shaders instead of trying to write them manually. For example, a generator for all unary shaders is a natural fit because everything but the element-wise function is the same.
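
For instance, a unary-shader generator might look roughly like the sketch below, assuming f32 storage buffers and a fixed workgroup size (the names and structure are illustrative, not Piston’s):

// Each entry is just the element-wise expression; the surrounding WGSL is shared.
const UNARY_OPS = {
  relu: "max(x, 0.0)",
  neg: "-x",
  sqrt: "sqrt(x)",
  log: "log(x)",
  sin: "sin(x)",
} as const;

function unaryShader(op: keyof typeof UNARY_OPS): string {
  return /* wgsl */ `
    @group(0) @binding(0) var<storage, read> src: array<f32>;
    @group(0) @binding(1) var<storage, read_write> dst: array<f32>;

    @compute @workgroup_size(64)
    fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
      let i = gid.x;
      if (i >= arrayLength(&src)) {
        return;
      }
      let x = src[i];
      dst[i] = ${UNARY_OPS[op]};
    }
  `;
}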

Finally, it is a matter of grave importance that you write tests for your kernels. For each kernel you write without testing, you will accrue a debt with Entropy, and deep neural networks will mysteriously refuse to converge for you. Spare yourself this spirit-corroding experience and make a plan to write tests from kernel one. One particularly convenient way to do this, assuming the operations you write are also in PyTorch, is to effectively write the tests in Python: you can generate and serialize (test case label, input, output, acceptable error, operation) tuples ahead of time and wire up your tests to simply check that for each of these cases, the output of your framework is within an acceptable error of PyTorch’s. Once all these tests pass, you should technically have everything you’ll need to run a forward pass. Now let’s talk about implementing backpropagation.
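
Here is a sketch of what the comparison side of that could look like; the TestCase shape and the runOp helper are assumptions about yourtorch, not anything Piston actually exposes:

// One serialized case per (operation, input) pair, generated ahead of time with PyTorch.
interface TestCase {
  label: string;      // e.g. "relu_f32_2x3"
  op: string;         // which kernel to dispatch
  input: number[];
  sizes: number[];
  expected: number[]; // PyTorch's output, flattened
  atol: number;       // acceptable absolute error
}

// Hypothetical: runs the named kernel in yourtorch and reads the result back.
declare function runOp(op: string, input: number[], sizes: number[]): Promise<number[]>;

async function checkCase(tc: TestCase): Promise<void> {
  const actual = await runOp(tc.op, tc.input, tc.sizes);
  for (let i = 0; i < tc.expected.length; i++) {
    if (Math.abs(actual[i] - tc.expected[i]) > tc.atol) {
      throw new Error(`${tc.label}: mismatch at index ${i}: ${actual[i]} vs ${tc.expected[i]}`);
    }
  }
}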

Backward-mode automatic differentiation, or backpropagation

Considering the many consummate expositions on backpropagation available at this point, I’ll skip over its details, but if you feel shaky about this AI fundamental, follow the footnotes9.

As yourtorch’s future author, you will need to implement the backprop algorithm and either (1) write a kernel to compute the gradient for each differentiable Tensor operation you define (i.e. create an addition_backward kernel in addition to addition), or (2) compute each gradient with a graph of tensor operations mathematically equivalent to the kernel you would’ve written (there’s a toy example of this after the notes below). The upside of the latter is that it’s easier to implement and test, because many gradients can be composed from the operations you already wrote for your forward pass; the downside is that computing a gradient using multiple kernels requires intermediate trips to memory, which is bad because we’re primarily constrained by memory, not compute. I went the easy route with Piston, partly because HuggingFace’s Candle project does (candle/candle-core/src/backprop.rs), which allowed me to use a lot of their code as a starting point. Some concrete implementation notes here:

  1. Just like your tensor operations, make sure you have a way to compare gradients with PyTorch.
  2. If you’re building a WebAssembly backend in Rust or C++ and composing gradients out of other tensors, it will save you some grief to define these gradient functions within that backend. This is because of how garbage collection works in JavaScript, which we’ll talk about later.
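
As the toy example promised above, here is option (2) for element-wise multiplication: the gradient is composed entirely out of the multiplication you already wrote for the forward pass (broadcasting ignored, and Tensor.mul is assumed to exist):

import { Tensor } from 'yourtorch';

// d/d(lhs) of (lhs * rhs) is rhs, and d/d(rhs) is lhs, so both gradients are
// just another multiplication against the incoming gradient.
function mulBackward(gradOutput: Tensor, lhs: Tensor, rhs: Tensor): [Tensor, Tensor] {
  const gradLhs = gradOutput.mul(rhs);
  const gradRhs = gradOutput.mul(lhs);
  return [gradLhs, gradRhs];
}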

If you tried to implement yourtorch right now, with everything we’ve discussed so far, you might find that there’s an impedance mismatch between the synchronous nature of your nascent PyTorch, and the asynchronous queue-driven way WebGPU sees the world. We’ll start chipping away at this in the next section.

An abbreviated introduction to graph-based execution

Some table-setting: the eager execution strategy is exactly what you might expect. It’s what PyTorch does: when you call y = x * w + b, it immediately evaluates z = x * w, then y = z + b—just like a calculator would. This might seem too obvious to bother pointing out. But it turns out that there’s another execution strategy, graph execution, where the entire program is submitted for evaluation at once. Let’s examine why you’ll want to consider graph execution for yourtorch.

Eager execution has some problems. For one, there’s a small overhead associated with submitting and waiting for individual GPU kernels, which can compound over the many, many such invocations done in the service of industrial-scale modern deep learning. Nvidia addresses this on their GPUs with CUDA Graphs, which lets you significantly reduce this overhead by instead submitting a full execution graph of many CUDA programs all at once. XLA takes this a step further and does low-level optimizations like operator fusion to the graph before shipping it off to the TPU, illustrating another advantage of graph execution. Graph-level optimizations turn out to be quite important because most of modern deep learning is memory-bound, so a smaller set of fused shaders means fewer expensive trips to high-bandwidth memory. Christopher Fleetwood has an excellent and digestible analysis that deals with this sort of performance consideration from the perspective of optimizing Transformer inference.

In Piston, I didn’t implement any form of graph optimization10, as hard as I looked for an excuse to be sniped by the problem, but I did adopt a graph execution model. At a high level, here’s why: WebGPU’s interface for submitting programs to the GPU is GPUQueue, which accepts a full graph of commands at once. This makes it more like CUDA Graphs and XLA11 than PyTorch/CUDA, and it is asynchronous. You’d find that overhead would add up fast if, for every single tensor operation in your graph, you submitted a separate command buffer to the queue and awaited its result before proceeding. So let’s talk about how to add graph execution to yourtorch, in two stages.
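
To make the GPUQueue point concrete, here is a rough sketch of the difference: many dispatches encoded into one command buffer and submitted together, instead of one submission (and one await) per tensor operation. The node shape is an assumption, standing in for a post-ordered graph where each node already has a pipeline and bind group:

interface GraphNode {
  pipeline: GPUComputePipeline;
  bindGroup: GPUBindGroup;
  workgroups: number;
}

function submitGraph(device: GPUDevice, nodes: GraphNode[]) {
  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  for (const node of nodes) {
    pass.setPipeline(node.pipeline);
    pass.setBindGroup(0, node.bindGroup);
    pass.dispatchWorkgroups(node.workgroups);
  }
  pass.end();
  // One submit for the whole graph, rather than one per operation.
  device.queue.submit([encoder.finish()]);
}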

Baby’s first inference-focused tensor API

For pedagogical and rhetorical effect, we’ll pretend to implement graph execution twice: first, as if we cared only about inference, and then again with an eye toward training. If you don’t particularly care for the extra intuition, you can skim this section and jump to the next one.

We’ll use the inference-focused Ratchet library as the blueprint for this section. This is for three reasons:

  1. It was built from the ground up with WebGPU’s execution model in mind, via wgpu. See ratchet/ARCHITECTURE.md#Design Decisions.
  2. It is not nearly as large as something like PyTorch, and so even advanced functionality is often in only one or two files, which makes it a pedagogical godsend.
  3. Piston is a hard fork of Ratchet, so I have a thing for it.

At this point in yourtorch’s development, you know you’ll want a graph-based execution model that plays well with WebGPU. I’ll also tell you for free that buffer allocation is quite slow in WebGPU, so you’ll want to reuse allocated buffers as much as possible. This is an important performance consideration that I won’t really cover in this blog post; see ratchet/crates/ratchet-core/src/gpu/buffer_allocator/allocator.rs and ratchet/crates/ratchet-core/src/gpu/pools for details on how Christopher did it—with only minor modifications, this is what I did too. I think, given enough time and tasked with creating an API like PyTorch with WebGPU as a first-class target, you might independently come up with something a lot like Ratchet:

import { Module, Tensor, Device, zeros, randn, gpu, cpu } from 'yourtorch';

class SimpleLinear extends Module {
    weight: Tensor;
    bias: Tensor;

    constructor(inFeatures: number, outFeatures: number, device: Device) {
        super();
        this.weight = zeros([outFeatures, inFeatures], { device });
        this.bias = zeros([outFeatures], { device });
    }

    schedule(input: Tensor) {
        return input.matmul(this.weight.T).add(this.bias);
    }
}

const inputs = randn([1], { device: gpu });
const linear = new SimpleLinear(1, 1, gpu);

(await linear.schedule(inputs).resolve()).to(cpu).item();

What this would do behind the scenes, when you call resolve(), is:

  1. Create a post-order traversal of the tensor operations—we choose post-order because it outputs a list where each node is inserted only after all of its children, giving us the important property that a given node is always calculated after its dependencies (there’s a sketch of this after the list).
  2. Allocate a bunch of buffers from a pool, like we discussed briefly, so they can be reused later.
  3. Translate the operations into a WebGPU queue and submit it (leaving out a lot of details here).
  4. Await a mapAsync call on an output buffer.
  5. Associate that buffer with the output tensor.
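
Here is the sketch of step 1 promised above, assuming each tensor records the operation that produced it and that operation’s inputs:

interface TensorNode {
  inputs: TensorNode[]; // empty for leaf/constant tensors
}

function postOrder(root: TensorNode): TensorNode[] {
  const visited = new Set<TensorNode>();
  const order: TensorNode[] = [];
  const visit = (node: TensorNode) => {
    if (visited.has(node)) return;
    visited.add(node);
    for (const input of node.inputs) visit(input); // children first...
    order.push(node); // ...then the node itself, so dependencies always come earlier
  };
  visit(root);
  return order;
}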

You’ll need some sort of device-transfer affordance, like to(device), to move the resulting output tensor between the GPU and CPU, where you can actually use its data.
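
Under the hood, a GPU-to-CPU transfer boils down to a buffer readback. Roughly, and assuming an f32 tensor whose byte length is a multiple of four and whose storage buffer was created with COPY_SRC usage:

async function readBack(device: GPUDevice, src: GPUBuffer, byteLength: number): Promise<Float32Array> {
  // WebGPU won't let you map a storage buffer directly, so copy into a staging buffer first.
  const staging = device.createBuffer({
    size: byteLength,
    usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
  });

  const encoder = device.createCommandEncoder();
  encoder.copyBufferToBuffer(src, 0, staging, 0, byteLength);
  device.queue.submit([encoder.finish()]);

  await staging.mapAsync(GPUMapMode.READ);
  const data = new Float32Array(staging.getMappedRange().slice(0)); // copy out before unmapping
  staging.unmap();
  return data;
}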

This story is great for inference: usually, you only want a single output, like a token ID. Or if you have a few outputs, it’s not much more expensive to run resolve() a few more times, because you’ll have computed most of the model in the first pass, so you’ll only spend a little extra time repeatedly submitting queues and traversing a mostly identical graph. But in backprop, the “output” of a forward-backward pass includes tens to hundreds of gradients, so the overhead of naively resolve()-ing each of them separately adds up quickly.

LazyTensor: graph-based execution for training

Let’s start by defining some criteria for our solution. It should:

  • Only require a single overall post-order—we’d prefer it if we could translate the forward pass, loss computation, gradient computations, and optimizer update into a single queue of WebGPU commands.
  • Require minimal or no changes to module and optimizer neural network code.

The first place I looked when trying to solve this problem for Piston was PyTorch’s XLA support, on the hunch that despite the eager-graph mismatch, Google would be determined to let developers use PyTorch on its TPUs. As it turns out, torch_xla implements something called LazyTensor, which happens to meet all our requirements.

Here’s how this differs from the first implementation:

  1. In the constructor and destructor of your Tensor type, add code to respectively register/deregister it with singleton LazyTensor state. In the paper, this is the DeviceContextArena.
  2. Create an asynchronous mark_step method accessible at least to your optimizer/module code, but preferably globally, that creates a combined, deduplicated post-order for all registered tensors and executes it (there’s a sketch of this after the list). I’ll belabor here that by default, nothing should run until you call mark_step—it’s all deferred.
  3. Call mark_step near the end of the step() method for all optimizers you define.
  4. Update your tensor code to resolve its post-order and compute whenever you request the value of the tensor, like when transferring to the CPU or converting to a vector type. This effectively makes all methods that call resolve() async (see function coloring).
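
A rough sketch of the bookkeeping behind steps 1 and 2 follows; none of this is Piston’s or torch_xla’s actual code, and postOrderAll and executeOnGpu stand in for the machinery from the previous sections:

import { Tensor } from 'yourtorch';

// Assumed helpers: a combined, deduplicated post-order over many roots, and an
// executor that encodes, submits, and awaits a single WebGPU queue submission.
declare function postOrderAll(tensors: Tensor[]): Tensor[];
declare function executeOnGpu(order: Tensor[]): Promise<void>;

class DeviceContextArena {
  // Step 1: every live tensor registers itself here when constructed and
  // deregisters when destroyed/disposed.
  private live = new Set<Tensor>();

  register(tensor: Tensor) { this.live.add(tensor); }
  deregister(tensor: Tensor) { this.live.delete(tensor); }

  // Step 2: nothing runs until this is called; then everything pending
  // (forward, loss, gradients, optimizer update) executes in one go.
  async markStep(): Promise<void> {
    const order = postOrderAll([...this.live]); // shared subgraphs appear only once
    await executeOnGpu(order);
  }
}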

With this implemented, you should find yourself fighting WebGPU quite a bit less. And if you’ve brought in torch.nn analogues (modules, optimizers, a dataloader, etc.) you should have everything you need to write a training loop. That is, if you’ve also implemented your training loop in WebAssembly via Rust or C++, and can drop Tensor objects as they leave scope, RAII-style. But if you, like me, defined your training loop in JavaScript and consequently let the browser manage your Tensor references, you’ll find yourself cursing unpredictable garbage collection behavior, which brings us to the last major consideration I have to offer you.

When JavaScript’s garbage collector won’t do, just build a worse one

JavaScript has no garbage collection API! There is no supported way to manually free unused references or force a collection. Usually this is okay—presumably garbage collection algorithms monitor memory pressure and clean things up more often when we need them to. But garbage collection falls completely flat when it comes to managing tensor references, because it can see only the relatively little memory occupied by the references, and has no concept of the full iceberg of associated VRAM. This becomes a real problem: your training loop creates hundreds of intermediate tensors each step, and you’ll blow through VRAM the garbage collector could not care less about. So you’ll want a way to clean up tensor references when you’re done with them.

Before you look at what I did—if you’re taking notes—you might consider noodling this one by yourself for a bit, because I’m not convinced my solution is obvious or optimal. But my thought process went as follows: in lieu of the sort of reference-counted data lifetimes I’d gotten used to with Rust, I wanted to simulate RAII for particularly hot manually-defined “scopes,” like a single training step, or a forward pass when sampling autoregressively. So I’d need a way to keep track of all the tensors created at the beginning of this virtual scope, and manually deallocate them at the end. To this end, I implemented another PyTorch feature, function modes, in Piston, and created WeakTensorMode, which does exactly that: tracking tensors and manually deallocating them in the cleanup (Symbol.dispose) method.

In practice, I found it’s also useful to have a pin(tensor) method that can persist tensors created during the function mode by removing them from its cleanup list. This comes in handy when, for example, you want to clean up forward-pass intermediates but keep outputs like logits or loss. Support for nested function modes becomes useful here, because the outermost weak mode can make sure anything pinned by an inner mode is still cleaned up eventually. Additionally, if you find, as I did, that you often want to drop all but the output tensor(s) of a given closure, it’s worth creating a globally-accessible weak(closure) convenience function that internally creates a weak mode, runs the closure, and pins any tensors in the returned object before cleaning up.
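
Put together, a weak(closure) convenience in the spirit of the above might look like the following sketch; WeakTensorMode’s exact surface here (including pinAll) is a guess for illustration, not Piston’s API:

// Assumed: a function mode that tracks every tensor created while it's active,
// can pin individual tensors so they survive cleanup, and frees the rest when
// its Symbol.dispose method runs.
declare class WeakTensorMode {
  pinAll(value: unknown): void; // pin any tensors reachable from `value`
  [Symbol.dispose](): void;     // free the GPU buffers of everything still tracked
}

async function weak<T>(fn: (mode: WeakTensorMode) => T | Promise<T>): Promise<T> {
  const mode = new WeakTensorMode();
  try {
    const result = await fn(mode);
    // Keep whatever the closure returns (e.g. a loss tensor) alive; an enclosing
    // weak mode, if there is one, will still clean it up eventually.
    mode.pinAll(result);
    return result;
  } finally {
    mode[Symbol.dispose]();
  }
}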

A yourtorch training loop

At this point, putting together everything we’ve done so far, you should be able to write a simple training loop that looks roughly like the following:

// sgd.ts

export class SGD extends Optimizer {
    // …

    async step(/* … */) {
        // Do optimizer updates with gradients

        await this.device.markStep();

        return loss;
    }
}
// train.ts

import { Dataloader, weak, cpu, gpu } from 'yourtorch';
import { log } from './metrics';
import { YourModel } from './model';
import { YourDataset } from './dataset';
import { collateFn } from './collate';
import { SGD } from './sgd';

async function train() {
    const model = new YourModel();

    const trainDataloader = new Dataloader(new YourDataset({ split: "train" }), { batchSize: 32, collateFn });
    const validationDataloader = new Dataloader(new YourDataset({ split: "validation" }), { collateFn });

    const optimizer = new SGD(model.parameters(), gpu, { lr: 1e-5, /* other parameters */ });

    model.train();

    let step = 0;
    for (const batch of trainDataloader) {
        await weak(async (mode) => {
            // Make sure we delete the batch at the end of the step
            mode.markWeak(batch);
            const [inputs, targets] = batch;

            const loss = await weak(async () => {
                const [logits, loss] = model.forward(
                    await inputs.to(gpu), await targets.to(gpu)
                );

                // Loss will be pinned automatically because it's a return value,
                // but the outer weak mode will drop it
                return loss;
            });

            loss.backward();
            await optimizer.step();
            optimizer.zeroGrad();

            const validationLoss = await weak(async () => {
                // compute and return validation loss (elided)
            });

            log({ "train/loss": await (await loss.to(cpu)).item(), "validation/loss": await (await validationLoss.to(cpu)).item() }, { step });

            step++;
        });
    }
}

Given that you’ve probably used PyTorch, this should look familiar except for the changes we added in the last two sections. And that’s it! Add model and optimizer implementations (you can look at mine if you like) and you can train your own tiny language models in yourtorch. Or, if you’d rather not wait, go check out Sequence Toy!


Thanks to Grant Pitt and Luci Sullivan for providing feedback on versions of this post, David Neto of Google’s Dawn team for answering a WebGPU question, and Christopher Fleetwood for pushing me to write the damn thing.


Appendix: visualizing activations

One unique advantage of being on the web platform—and not server Python—is that it is geared toward UI, so you might imagine that it’s relatively easier to make your neural network trainer even more impractical by adding support for visualizing neural network activations.

To help you weigh whether you should do this, consider some of the things you’ve already done to solve other problems that serendipitously make this easier:

  • Function modes make hooking individual operations much easier; getting activations on the fly could have otherwise been much more tedious, and possibly less granular.
  • You can display WebGPU buffers on WebGPU canvases with WebGPU shaders without first transferring to the CPU.

If you want to go the whole nine yards and define a query language for your project, like I did with Sequence Toy, you’ll still need to define, at a high level, a grammar (I used lezer, which is an especially good choice if you plan on using CodeMirror at any point), a parser, and an interpreter responsible for actually grabbing the activations. For this, I used a function mode as well as an implementation of forward hooks and forward pre-hooks.
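
If it helps to see the shape of it, here is a generic sketch of forward hooks (not Sequence Toy’s implementation; Tensor and the module base class are stand-ins):

import { Tensor } from 'yourtorch';

type ForwardHook = (module: HookableModule, inputs: Tensor[], output: Tensor) => void;

abstract class HookableModule {
  private forwardHooks: ForwardHook[] = [];

  // Returns a function that unregisters the hook.
  registerForwardHook(hook: ForwardHook): () => void {
    this.forwardHooks.push(hook);
    return () => {
      this.forwardHooks = this.forwardHooks.filter((h) => h !== hook);
    };
  }

  // Call this instead of forward() directly so hooks get to observe activations.
  callForward(...inputs: Tensor[]): Tensor {
    const output = this.forward(...inputs);
    for (const hook of this.forwardHooks) hook(this, inputs, output); // e.g. stash activations for display
    return output;
  }

  abstract forward(...inputs: Tensor[]): Tensor;
}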


  1. Dialects for Humans: Sounding Distinct from LLMs - Ben Gubler
  2. I will support Piston only as much as I feel like it.
  3. OpenAI o1, DeepSeek R1
  4. https://dynomight.net/gpt-2/
  5. This is not completely true—it might let you scale to two, if your browser can determine that one of them is more powerful: you can tell the browser that you’d like a “high-performance” or “low-power” adapter using powerPreference.
  6. With a serious leg up from Ratchet, which was my starting point for Piston. I’ll mention this a few times.
  7. By way of Ratchet.
  8. In addition to the 38 compute shaders Ratchet implements, I wrote 43: flip, affine, alibi, arange, bernoulli, binary (pow, maximum, minimum), cmp (eq, ne, le, ge, lt, gt, logical and, logical or, logical xor), eye, pointwise fill, gather, index add, lerp, multinomial, one hot, pow, reduce (sum, min, max, argmin, argmax, norm2), scatter add, ternary (addcdiv, addcmul), topk, tril/triu, unary (relu2, reciprocal, swiglu, logical not, isnan, isinf), where.
  9. Neural networks and deep learning, Calculus on Computational Graphs: Backpropagation — colah’s blog, 3Blue1Brown - Backpropagation calculus, GitHub - karpathy/micrograd, candle/candle-core/src/backprop.rs.
  10. Ratchet has the seeds of this with in-place memory detection.
  11. More like CUDAGraphs than XLA. I confirmed that graph-level optimization is out of scope for WebGPU: see Google Groups discussion. Thanks to David Neto and inner-daemons.

Citation

@online{vinhowe2025,
  author = {Vin Howe},
  title = {You probably shouldn’t train a language model in your browser—here’s how},
  date = {2025-11-21},
  url = {http://sveltekit-prerender/blog/train-a-language-model-in-your-browser}
}