Last I checked, PyTorch, TensorFlow, and JAX all sit at a layer of abstraction above GPU kernels. They make GPU kernels available (basically as nodes in the computational graph you mention), but they don't let you write GPU kernels.
Yes, they kinda do. The computational graph you specify is completely different from the execution schedule it is compiled into. Whether it's 1, 2, or N kernels is irrelevant as long as it runs fast. Mojo being an HLL is conceptually no different from Python. Whether it will, in the future, become better for DNNs, time will tell.
I assume HLL = Higher-Level Language? Mojo definitely exposes lower-level facilities than Python. Chris has even described Mojo as "syntactic sugar over MLIR". (For example, the native integer type is defined in library code as a struct.)
> Whether it's 1, 2, or N kernels is irrelevant.
Not sure what you mean here. But new kernels are written all the time (flash-attn is a great example). One can't do that in plain Python. E.g., flash-attn was originally written in C++ CUDA, and now in Triton.
Well, Mojo hasn't been released so we can't precisely say what it can and can't do. If it can emit CUDA code then it does it by transpiling Mojo into CUDA. And there is no reason why Python can't also be transpiled to CUDA.
What I mean here is that DNN code is written at a much higher level than kernels. Kernels are just building blocks you use to instantiate your dataflow.
`torch.compile` sits at both the level of the computation graph and the level of GPU kernels, and it can fuse your operations using the Triton compiler. I think something similar applies to JAX and TensorFlow by way of XLA, but I’m not 100% sure.
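To make "fuse your operations" concrete, here is a toy sketch in plain Python. It is not how `torch.compile` or Triton actually work; the helper names (`run_unfused`, `fuse`, `run_fused`) are made up for illustration. It only shows the idea that three elementwise ops can run as three passes over the data or as one combined pass with the same result.

```python
from functools import reduce

def run_unfused(ops, xs):
    """Run each op as its own 'kernel': one full pass and one temporary list per op."""
    passes = 0
    for op in ops:
        xs = [op(x) for x in xs]
        passes += 1
    return xs, passes

def fuse(ops):
    """Compose elementwise ops into a single function applied once per element."""
    return lambda x: reduce(lambda acc, op: op(acc), ops, x)

def run_fused(ops, xs):
    """Run the composed function in a single pass over the data."""
    fused = fuse(ops)
    return [fused(x) for x in xs], 1

# Hypothetical graph: scale, shift, then ReLU.
ops = [lambda x: x * 2.0, lambda x: x + 1.0, lambda x: max(x, 0.0)]
data = [-3.0, -0.5, 0.0, 2.0]

unfused_out, unfused_passes = run_unfused(ops, data)
fused_out, fused_passes = run_fused(ops, data)
print(unfused_out == fused_out, unfused_passes, fused_passes)
```

The graph you wrote (three ops) is the same in both cases; only the execution schedule differs. On a GPU the saving would come from not writing the intermediates back to memory between kernels.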
Good point. But the overall point about Mojo offering a different level of abstraction than Python still stands: I imagine that no amount of magic/operator-fusion/etc. in `torch.compile()` would let one get reasonable performance for an implementation of, say, flash-attn. One would have to use CUDA/Triton/Mojo/etc.
But Python already operates at several levels of abstraction: you mention Triton yourself, and there is a new Python CUDA API too (the one similar to Triton). What's more, Flash Attention 4 is actually written in Python.
Somehow Python managed to be both a high-level and a low-level language for GPUs…
IIUC, Triton uses Python syntax, but it has a separate compiler (which is kinda what Mojo is doing, except Mojo's syntax is a superset of Python's instead of a subset, like Triton's). I think it's fair to describe it as a different language (otherwise, we'd have to describe Mojo as "Python" too). Triton's website and repo describe it as "the Triton language and compiler" (as opposed to, I dunno, "Write GPU kernels in Python").
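A minimal sketch of what I mean by "Python syntax, separate compiler", using only the standard `ast` module. This is not Triton's implementation; `emit_expr`, `emit_function`, and the `axpy` source are hypothetical, and the `__device__ float` output is just a C-like string for illustration. The point is that the "kernel" is only ever parsed, never run by CPython, and anything outside the tiny supported subset is rejected.

```python
import ast

# Hypothetical kernel, held as source text; CPython never executes it.
SRC = "def axpy(a, x, y):\n    return a * x + y\n"

_OPS = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*", ast.Div: "/"}

def emit_expr(node):
    """Translate a small subset of Python expressions into a C-like string."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return repr(node.value)
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        op = _OPS[type(node.op)]
        return f"({emit_expr(node.left)} {op} {emit_expr(node.right)})"
    raise NotImplementedError(f"unsupported syntax: {type(node).__name__}")

def emit_function(src):
    """Accept only 'def f(args): return <expr>' and emit a C-like device function."""
    fn = ast.parse(src).body[0]
    if not isinstance(fn, ast.FunctionDef) or len(fn.body) != 1:
        raise NotImplementedError("expected a single-statement function")
    ret = fn.body[0]
    if not isinstance(ret, ast.Return) or ret.value is None:
        raise NotImplementedError("expected a single return statement")
    params = ", ".join(f"float {a.arg}" for a in fn.args.args)
    return f"__device__ float {fn.name}({params}) {{ return {emit_expr(ret.value)}; }}"

print(emit_function(SRC))
```

Same surface syntax as Python, but the semantics are whatever the second compiler decides, which is why I'd call it a different language.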
Also, flash attention is at v3-beta right now? [0] And it requires one of CUDA/Triton/ROCm?
From this perspective PyTorch is a separate language too, at least as soon as you start using `torch.compile` (only a subset of PyTorch Python will be compilable). That’s a strength of Python: it’s great for describing things and later for analyzing them (and compiling them, for example).
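A toy tracer shows why only a subset ends up compilable. To be clear, this is an assumption-laden sketch and not `torch.compile`'s actual mechanism; `Proxy` and `trace` are invented names. Straight-line arithmetic gets recorded as a graph, while a branch on a traced value cannot be, because the value is unknown at trace time.

```python
class Proxy:
    """Stand-in for a tensor: records operations instead of computing them."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph

    def _record(self, op, other):
        rhs = other.name if isinstance(other, Proxy) else repr(other)
        out = Proxy(f"t{len(self.graph)}", self.graph)
        self.graph.append((out.name, op, self.name, rhs))
        return out

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)

    def __gt__(self, other):
        return self._record("gt", other)

    def __bool__(self):
        # An `if` on a traced value needs a concrete answer we do not have.
        raise TypeError("data-dependent branch: value unknown at trace time")

def trace(fn):
    """Call fn on a Proxy and return the recorded list of operations."""
    graph = []
    fn(Proxy("x", graph))
    return graph

def traceable(x):
    return x * 2 + 1

def not_traceable(x):
    if x > 0:  # triggers Proxy.__bool__
        return x * 2
    return x + 1

print(trace(traceable))
```

Calling `trace(not_traceable)` raises `TypeError` here; a real compiler would have to fall back to the interpreter or reject the function at that point.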
Just to be clear here: you use Triton from plain Python, and it runs the compilation internally.
Just like I’m pretty sure not all of Mojo can be used to write kernels? I might be wrong here, but it would be very hard to fit general-purpose code into kernels (and, to be frank, pointless: constraints bring speed).
Triton, CUDA, etc, let one write GPU kernels.