
Is PyTorch a hard requirement, or just the assumed default? #101

@pszemraj

Searching the repo, I can see PyTorch is clearly assumed, but I couldn't find explicit clarification on whether it's a hard requirement. This is related to #17 but specific enough for its own issue.

This matters because a stated goal is 'fostering architectural innovation', and torch's compile/runtime stack is biased toward Llama-style transformers¹. Over the 2.x releases, support and performance for ops outside that mold (complex-valued ops, non-standard scan patterns, and so on) have remained minimal or absent. Constraining entries to PyTorch plus a fixed GPU-time budget risks producing results that are artifacts of those biases rather than genuine architectural insights, likely converging on odd Llama ablations that happen to fit the size constraint².
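To make "non-standard scan pattern" concrete: the recurrences in question look roughly like the damped complex EMA below, written as the sequential loop such recurrences typically reduce to in eager PyTorch. This is a minimal illustrative sketch with toy shapes and a made-up function name, not code from any repo.

```python
import torch

def complex_ema_loop(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Damped complex EMA over the time axis: h_t = alpha*h_{t-1} + (1-alpha)*x_t.
    x: (T, D) complex64, alpha: (D,) complex64 with |alpha| < 1.
    Written as the step-by-step loop this pattern usually falls back to."""
    h = torch.zeros_like(x[0])
    out = []
    for t in range(x.shape[0]):
        h = alpha * h + (1 - alpha) * x[t]
        out.append(h)
    return torch.stack(out)

# toy usage
T, D = 128, 16
x = torch.randn(T, D, dtype=torch.complex64)
alpha = 0.9 * torch.exp(torch.complex(torch.zeros(D), 0.1 * torch.rand(D)))
y = complex_ema_loop(x, alpha)  # (T, D) complex64
```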

A concrete example from my own work with MEGALODON (a trainable/interpretable reimplementation built to explore inductive bias):

  • PyTorch (repo): 5.9% lower val loss than Llama with 10% fewer params on equal steps, but ~2.6x slower and ~1.7x more VRAM, because torch support for its complex-valued ops is poor.
  • JAX (repo): native JIT compilation of those same ops closes the gap entirely (a sketch of this formulation follows the table):
| Model     | Params | Val Loss @1200 | BPC  | Time  |
|-----------|--------|----------------|------|-------|
| Megalodon | 11.28M | 1.49           | 2.15 | ~3.2m |
| Llama     | 12.49M | 1.53           | 2.21 | ~3.1m |
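For contrast with the loop above, here is the same recurrence expressed as an associative scan in JAX, which is the sort of formulation that JITs and parallelizes along the time axis. Again, a minimal hypothetical sketch, not the MEGALODON code.

```python
import jax
import jax.numpy as jnp

def complex_ema_scan(x: jnp.ndarray, alpha: jnp.ndarray) -> jnp.ndarray:
    """Same damped complex EMA, h_t = alpha*h_{t-1} + (1-alpha)*x_t, but as an
    associative scan so the whole recurrence compiles into one parallel kernel.
    x: (T, D) complex64, alpha: (D,) complex64 with |alpha| < 1."""
    # Represent each step as an affine map h -> a*h + b and compose maps:
    # (a1, b1) followed by (a2, b2) gives (a1*a2, a2*b1 + b2).
    a = jnp.broadcast_to(alpha, x.shape)
    b = (1 - alpha) * x

    def combine(left, right):
        a1, b1 = left
        a2, b2 = right
        return a1 * a2, a2 * b1 + b2

    _, h = jax.lax.associative_scan(combine, (a, b), axis=0)
    return h

ema_jit = jax.jit(complex_ema_scan)

# toy usage
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (128, 16)) + 1j * jax.random.normal(k2, (128, 16))
alpha = 0.9 * jnp.exp(1j * 0.1 * jax.random.uniform(k3, (16,)))
y = ema_jit(x, alpha)  # (128, 16) complex64
```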

Relevance to this challenge: an architecture that's demonstrably better per parameter would never place on the leaderboard, because the wallclock penalty is a PyTorch penalty, not a model penalty.

I don't have a clean fix. Some ideas: allow a select set of alternative frameworks, normalize for FLOP utilization instead of raw wallclock (a sketch of what that might look like is below), or split a separate "parameter-efficiency" track off from the "train-in-10-min" track. Open to discussion. The broader point is that ML research is more constrained by framework assumptions than we tend to acknowledge, and this challenge, being explicitly about architectural creativity, is a good place to surface that.
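As one illustration of the FLOP-normalization idea: score on MFU-style compute rather than time the framework happened to take. The 6·N·D estimate and the A100 bf16 peak are standard reference values; the function names and the example budget are hypothetical.

```python
def approx_train_flops(n_params: int, n_tokens: int) -> float:
    # Standard ~6*N*D estimate of dense-model training FLOPs.
    return 6.0 * n_params * n_tokens

def mfu(n_params: int, n_tokens: int, wallclock_s: float,
        peak_flops_per_s: float = 312e12) -> float:
    # Model FLOPs utilization: achieved FLOP/s over accelerator peak.
    # 312 TFLOP/s is the A100 bf16 peak, used here only as an example.
    return approx_train_flops(n_params, n_tokens) / (wallclock_s * peak_flops_per_s)

# e.g. rank submissions by val loss at a matched FLOP budget rather than a
# matched wallclock, so kernel coverage doesn't decide the winner:
# budget_flops = approx_train_flops(124_000_000, 2_500_000_000)  # hypothetical
```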

Footnotes

  1. Given that they made Llama, I can't really blame them here.

  2. Check out the Bittensor pretraining subnet challenge and its best models (aka subnet9/sn9; some examples) and how things were optimized over time. It also had a fixed model size but unlimited training time, with validation on FineWeb or similar. tl;dr: a bunch of weird Llama architecture ablations (at the time).
