
Simple ConvNet causes mismatched dtypes during to_edge() call #8206


Open
holgerroth opened this issue Feb 5, 2025 · 14 comments
Labels
module: training — Issues related to training models on edge devices
triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


holgerroth commented Feb 5, 2025

🐛 Describe the bug

I'm trying to export a simple ConvNet for CIFAR-10:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

The export script is here. The error happens during the to_edge() call.
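
For context, a minimal sketch of the export flow the script follows (assuming the same APIs and TrainingNet loss wrapper as the XOR training example; names and shapes here are illustrative, not the exact script):

import torch
import torch.nn as nn
from torch.export import export
from torch.export.experimental import _export_forward_backward
from executorch.exir import to_edge

# Wrap the model so the loss is computed inside the graph and is the first output,
# as required for on-device training.
class TrainingNet(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input, label):
        return self.loss(self.net(input), label)

net = TrainingNet(ConvNet())
inputs = (torch.randn(1, 3, 32, 32), torch.ones(1, dtype=torch.int64))
ep = export(net, inputs)            # capture the forward graph
ep = _export_forward_backward(ep)   # add the joint backward graph
ep = to_edge(ep)                    # <- SpecViolationError is raised here
et = ep.to_executorch()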

Traceback (most recent call last):
  File "/localhome/local-hroth/Code/executorch/extension/training/examples/XOR/export_model_cifar10.py", line 81, in <module>
    main()
  File "/localhome/local-hroth/Code/executorch/extension/training/examples/XOR/export_model_cifar10.py", line 68, in main
    ep = _export_model()
  File "/localhome/local-hroth/Code/executorch/extension/training/examples/XOR/export_model_cifar10.py", line 44, in _export_model
    ep = to_edge(ep)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 101, in wrapper
    return func(self, *args, **kwargs)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 1217, in to_edge
    edge_programs[name] = _generate_edge_program(name, config, program)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 799, in _generate_edge_program
    edge_program = ExportedProgram(
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/export/exported_program.py", line 916, in __init__
    self.validate()
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/export/exported_program.py", line 1466, in validate
    self._validate()
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/export/exported_program.py", line 1475, in _validate
    v().check(self)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/_export/verifier.py", line 166, in check
    self._check_graph_module(ep.graph_module)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/_export/verifier.py", line 290, in _check_graph_module
    self.check_additional(gm)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/verification/verifier.py", line 284, in check_additional
    _check_tensor_args_matching_op_allowed_dtype(gm)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/verification/verifier.py", line 205, in _check_tensor_args_matching_op_allowed_dtype
    raise SpecViolationError(
torch._export.verifier.SpecViolationError: These operators are taking Tensor inputs with mismatched dtypes:

Operator: <EdgeOpOverload: aten.convolution_backward.default>: schema = aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) with args: {'grad_output': torch.float32, 'input': torch.float32, 'weight': torch.float32, '__ret_0': torch.float32, '__ret_1': torch.float32}
stack trace: File "<eval_with_key>.21 from /localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1180 in wrapped", line 124, in forward
    convolution_backward_1 = torch.ops.aten.convolution_backward.default(where_7, primals_11, primals_1, [6], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]);  where_7 = primals_11 = primals_1 = None
Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding outputs.

Versions

Collecting environment information...
PyTorch version: 2.7.0.dev20250131+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.4
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 6 2024, 20:22:13) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-130-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA A100 80GB PCIe
GPU 2: NVIDIA A100 80GB PCIe
GPU 3: NVIDIA A100 80GB PCIe
GPU 4: NVIDIA A100 80GB PCIe
GPU 5: NVIDIA A100 80GB PCIe
GPU 6: NVIDIA A100 80GB PCIe
GPU 7: NVIDIA A100 80GB PCIe

Nvidia driver version: 550.120
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7H12 64-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 1
Core(s) per socket: 64
Socket(s): 2
Stepping: 0
Frequency boost: enabled
CPU max MHz: 2600.0000
CPU min MHz: 1500.0000
BogoMIPS: 5199.82
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es
Virtualization: AMD-V
L1d cache: 4 MiB (128 instances)
L1i cache: 4 MiB (128 instances)
L2 cache: 64 MiB (128 instances)
L3 cache: 512 MiB (32 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-63
NUMA node1 CPU(s): 64-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT disabled
Vulnerability Spec rstack overflow: Mitigation; SMT disabled
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] executorch==0.6.0a0+1fda542
[pip3] numpy==2.0.0
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.7.0.dev20250131+cpu
[pip3] torchao==0.8.0+git11333ba2
[pip3] torchaudio==2.6.0.dev20250131+cpu
[pip3] torchsr==1.0.4
[pip3] torchtune==0.5.0
[pip3] torchvision==0.22.0.dev20250131+cpu
[pip3] triton==3.1.0
[conda] Could not collect

cc @JacobSzwejbka


digantdesai commented Feb 5, 2025

Hmm, I guess the forward graph alone works and it fails with the backward graph? Also, I assume suppressing the verifier doesn't get you much further either?

cc @JacobSzwejbka - can you help? Not sure if this is a verifier issue or something to do with export.

@digantdesai digantdesai added module: training Issues related to training models on edge devices triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Feb 5, 2025
@holgerroth
Author

Correct. I tried adding this, which lets the to_edge() call pass:

    edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
    ep = to_edge(ep, compile_config=edge_compile_config)  

but then it fails during the to_executorch() call

/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/passes/memory_planning_pass.py:83: UserWarning: Function aten.convolution_backward.out's out0 kwarg value is None
  warnings.warn(
Traceback (most recent call last):
  File "/localhome/local-hroth/Code/executorch/extension/training/examples/XOR/export_model_cifar10.py", line 81, in <module>
    main()
  File "/localhome/local-hroth/Code/executorch/extension/training/examples/XOR/export_model_cifar10.py", line 68, in main
    ep = _export_model()
  File "/localhome/local-hroth/Code/executorch/extension/training/examples/XOR/export_model_cifar10.py", line 48, in _export_model
    ep = ep.to_executorch()
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 101, in wrapper
    return func(self, *args, **kwargs)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 1432, in to_executorch
    new_gm_res = memory_planning_pass.run(  # pyre-ignore[16]
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/passes/memory_planning_pass.py", line 116, in run
    self._set_alloc_node_spec(graph_module)
  File "/localhome/local-hroth/.venv_executorch/lib/python3.10/site-packages/executorch/exir/passes/memory_planning_pass.py", line 97, in _set_alloc_node_spec
    out_alloc_node.meta["spec"] = specs[i]
IndexError: list index out of range

@JacobSzwejbka
Contributor

I think this is because the very first conv node is returning None for the input gradient instead of an empty tensor. Let me follow up with the compiler team on this.
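
For illustration, a quick eager-mode check of that behavior (shapes chosen arbitrarily, not from the original model): because the model input does not require grad, the joint graph calls convolution_backward with output_mask=[False, True, True], and the grad_input slot comes back as None:

import torch

x = torch.randn(1, 3, 8, 8)    # model input, requires_grad=False
w = torch.randn(6, 3, 5, 5)    # conv weight
go = torch.randn(1, 6, 4, 4)   # upstream gradient for the conv output
grad_input, grad_weight, grad_bias = torch.ops.aten.convolution_backward.default(
    go, x, w, [6], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]
)
print(grad_input)                          # None: the input gradient is not computed
print(grad_weight.shape, grad_bias.shape)  # the weight and bias grads are returned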

@holgerroth
Author

FYI, I got it to work by doing this in my export script (which shouldn't be the right way):

    net = TrainingNet(ConvNet())
    x = torch.randn(BATCH_SIZE, 3, 32, 32)
    x.requires_grad = True

but that causes an error during training

E 00:00:01.219379 executorch:operator_registry.cpp:186] kernel 'aten::max_pool2d_with_indices_backward.grad_input' not found.
E 00:00:01.219416 executorch:operator_registry.cpp:187] dtype: 6 | dim order: [
E 00:00:01.219418 executorch:operator_registry.cpp:187] 0,
E 00:00:01.219420 executorch:operator_registry.cpp:187] 1,
E 00:00:01.219421 executorch:operator_registry.cpp:187] 2,
E 00:00:01.219423 executorch:operator_registry.cpp:187] 3,
E 00:00:01.219424 executorch:operator_registry.cpp:187] ]
E 00:00:01.219426 executorch:operator_registry.cpp:187] dtype: 6 | dim order: [
E 00:00:01.219428 executorch:operator_registry.cpp:187] 0,
E 00:00:01.219429 executorch:operator_registry.cpp:187] 1,
E 00:00:01.219430 executorch:operator_registry.cpp:187] 2,
E 00:00:01.219432 executorch:operator_registry.cpp:187] 3,
E 00:00:01.219433 executorch:operator_registry.cpp:187] ]


JacobSzwejbka commented Feb 5, 2025

kernel 'aten::max_pool2d_with_indices_backward.grad_input' not found.

We just don't have an implementation for this operator yet (cc @manuelcandales), although if it's possible to decompose it, we should likely register a decomposition and use that instead.
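
As a rough illustration (my sketch, not an existing ExecuTorch decomposition), such a decomposition could scatter the output gradient back to the argmax positions recorded by the forward pass; the kernel/stride/padding arguments can be ignored because the indices already encode the winning positions:

import torch

def max_pool2d_with_indices_backward_decomp(grad_output, input, indices):
    # indices are per-channel offsets into the flattened (H_in * W_in) input plane,
    # as returned by max_pool2d_with_indices. Accumulate in case windows overlap.
    grad_input = torch.zeros_like(input).flatten(-2)
    grad_input.scatter_add_(-1, indices.flatten(-2), grad_output.flatten(-2))
    return grad_input.view_as(input)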

FYI, I got it to work by doing this in my export script (which shouldn't be the right way)

The problem with this approach is that we have an invariant (which we apparently don't assert on, not great) that the number of grad outputs equals the number of parameters we output.

The output of a training model is (loss, *gradients, *parameters), and grads and params should be mapped 1-1 (so the first grad output corresponds to the first param, etc.).

We then emit some hidden functions named something like

__executorch_gradient_start_index
__executorch_param_start_index

which tell us where in the outputs list the respective groups start. If an input gradient is in the list, the mapping gets broken, which will mess up things like the TrainingModule (which wraps all of this under the hood and then gives you APIs that behave how you would expect).
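
To illustrate the invariant (a sketch, not ExecuTorch code; the two indices would come from the hidden methods named above, and the outputs are assumed to be laid out as loss, then grads, then params):

def split_training_outputs(outputs, grad_start_index, param_start_index):
    loss = outputs[0]
    grads = outputs[grad_start_index:param_start_index]
    params = outputs[param_start_index:]
    # The 1-1 mapping TrainingModule relies on: grads[i] is the gradient of params[i].
    assert len(grads) == len(params), "an extra (e.g. input) gradient breaks the mapping"
    return loss, list(zip(grads, params))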

@JacobSzwejbka
Contributor

After discussing some more with the compiler team, apparently the None output is expected behavior. I can get super in the weeds here about what's going on if anyone cares to hear, but otherwise I am testing a fix locally right now and should have it up today or tomorrow.


JacobSzwejbka commented Feb 5, 2025

Ok yeah I have a conv fix at least.

    # Imports assumed here (following the XOR training example):
    import torch
    from torch import nn
    from torch.export import export
    from torch.export.experimental import _export_forward_backward
    from executorch.exir import to_edge
    from executorch.extension.pybindings.portable_lib import (
        _load_for_executorch_from_buffer,
    )


    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(6, 6, 5)
            self.linear = nn.Linear(6, 2)
    
        def forward(self, x):
            return self.linear(self.conv1(x).flatten(1))
    
    
    # On device training requires the loss to be embedded in the model (and be the first output).
    # We wrap the original model here and add the loss calculation. This will be the model we export.
    class TrainingNet(nn.Module):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.loss = nn.CrossEntropyLoss()
    
        def forward(self, input, label):
            pred = self.net(input)
            return self.loss(pred, label)
    
    
    def main() -> None:
        torch.manual_seed(0)
    
        net = TrainingNet(Net())
        inputs = (torch.randn(1, 6, 5, 5), torch.ones(1, dtype=torch.int64))
        x = net(*inputs)
        print("eager loss:", x)
        x.backward()
        print("eager conv1 weight grad:", net.net.conv1.weight.grad)
        print("eager conv1 bias grad:", net.net.conv1.bias.grad)
        print("eager linear weight grad:", net.net.linear.weight.grad)
        print("eager linear bias grad:", net.net.linear.bias.grad)
        print(net(*inputs))
        # Captures the forward graph. The graph will look similar to the model definition now.
        # Will move to export_for_training soon which is the api planned to be supported in the long term.
        ep = export(net, inputs)
        print("Exported Program:", ep.graph)
        # Captures the backward graph. The exported_program now contains the joint forward and backward graph.
        ep = _export_forward_backward(ep)
        print("Joint Program:", ep.graph)
        # Lower the graph to edge dialect.
        ep = to_edge(ep)
        # Lower the graph to executorch.
        ep = ep.to_executorch()
    
        # ep.dump_executorch_program(True)
    
        et = _load_for_executorch_from_buffer(ep.buffer)
        y = et(inputs)
        for i in y:
            print("et out ", i)
    
        print(torch.allclose(net.net.conv1.weight.grad, y[1], atol=1e-03))
        print(torch.allclose(net.net.conv1.bias.grad, y[2], atol=1e-03))
        print(torch.allclose(net.net.linear.weight.grad, y[3], atol=1e-03))
        print(torch.allclose(net.net.linear.bias.grad, y[4], atol=1e-03))
    
    
    if __name__ == "__main__":
        main()

is passing for me on my stack. I'll start putting them up and merging them.

@holgerroth
Author

Sounds great! Thanks.

@JacobSzwejbka
Contributor

#8303

should be fixed on the latest branch, though the missing ops will still be around.

@holgerroth
Author

Thanks @JacobSzwejbka! I can export the ConvNet now, but I still get this error during training with my train.cpp:

/cmake-out/extension/training/train_xor --model_path=/tmp/foobar/cifar10.pte
Loaded 10000 images from /localhome/local-hroth/Data/CIFAR10/cifar-10-batches-bin/data_batch_1.bin
First image label: 6
Built image_batch 98304
Built label_batch 32
E 00:00:01.395511 executorch:operator_registry.cpp:186] kernel 'aten::max_pool2d_with_indices_backward.grad_input' not found.
E 00:00:01.395555 executorch:operator_registry.cpp:187] dtype: 6 | dim order: [
E 00:00:01.395558 executorch:operator_registry.cpp:187] 0,
E 00:00:01.395559 executorch:operator_registry.cpp:187] 1,
E 00:00:01.395561 executorch:operator_registry.cpp:187] 2,
E 00:00:01.395562 executorch:operator_registry.cpp:187] 3,
E 00:00:01.395563 executorch:operator_registry.cpp:187] ]
E 00:00:01.395564 executorch:operator_registry.cpp:187] dtype: 6 | dim order: [
E 00:00:01.395566 executorch:operator_registry.cpp:187] 0,
E 00:00:01.395568 executorch:operator_registry.cpp:187] 1,
E 00:00:01.395569 executorch:operator_registry.cpp:187] 2,
E 00:00:01.395570 executorch:operator_registry.cpp:187] 3,
E 00:00:01.395572 executorch:operator_registry.cpp:187] ]
E 00:00:01.395573 executorch:operator_registry.cpp:187] dtype: 4 | dim order: [
E 00:00:01.395574 executorch:operator_registry.cpp:187] 0,
E 00:00:01.395575 executorch:operator_registry.cpp:187] 1,
E 00:00:01.395577 executorch:operator_registry.cpp:187] 2,
E 00:00:01.395578 executorch:operator_registry.cpp:187] 3,
E 00:00:01.395580 executorch:operator_registry.cpp:187] ]
E 00:00:01.395581 executorch:operator_registry.cpp:187] dtype: 6 | dim order: [
E 00:00:01.395587 executorch:operator_registry.cpp:187] 0,
E 00:00:01.395588 executorch:operator_registry.cpp:187] 1,
E 00:00:01.395590 executorch:operator_registry.cpp:187] 2,
E 00:00:01.395593 executorch:operator_registry.cpp:187] 3,
E 00:00:01.395597 executorch:operator_registry.cpp:187] ]
E 00:00:01.395598 executorch:operator_registry.cpp:187] dtype: 6 | dim order: [
E 00:00:01.395599 executorch:operator_registry.cpp:187] 0,
E 00:00:01.395602 executorch:operator_registry.cpp:187] 1,
E 00:00:01.395606 executorch:operator_registry.cpp:187] 2,
E 00:00:01.395607 executorch:operator_registry.cpp:187] 3,
E 00:00:01.395609 executorch:operator_registry.cpp:187] ]
E 00:00:01.395612 executorch:method.cpp:599] Missing operator: [26] aten::max_pool2d_with_indices_backward.grad_input
E 00:00:01.395624 executorch:operator_registry.cpp:186] kernel 'aten::max_pool2d_with_indices_backward.grad_input' not found.
E 00:00:01.395627 executorch:operator_registry.cpp:187] dtype: 6 | dim order: [
E 00:00:01.395629 executorch:operator_registry.cpp:187] 0,
E 00:00:01.395630 executorch:operator_registry.cpp:187] 1,
E 00:00:01.395632 executorch:operator_registry.cpp:187] 2,
E 00:00:01.395637 executorch:operator_registry.cpp:187] 3,
E 00:00:01.395639 executorch:operator_registry.cpp:187] ]
E 00:00:01.395640 executorch:operator_registry.cpp:187] dtype: 6 | dim order: [
E 00:00:01.395644 executorch:operator_registry.cpp:187] 0,
E 00:00:01.395646 executorch:operator_registry.cpp:187] 1,
E 00:00:01.395648 executorch:operator_registry.cpp:187] 2,
E 00:00:01.395650 executorch:operator_registry.cpp:187] 3,
E 00:00:01.395652 executorch:operator_registry.cpp:187] ]
E 00:00:01.395656 executorch:operator_registry.cpp:187] dtype: 4 | dim order: [
E 00:00:01.395659 executorch:operator_registry.cpp:187] 0,
E 00:00:01.395662 executorch:operator_registry.cpp:187] 1,
E 00:00:01.395665 executorch:operator_registry.cpp:187] 2,
E 00:00:01.395668 executorch:operator_registry.cpp:187] 3,
E 00:00:01.395669 executorch:operator_registry.cpp:187] ]
E 00:00:01.395671 executorch:operator_registry.cpp:187] dtype: 6 | dim order: [
E 00:00:01.395672 executorch:operator_registry.cpp:187] 0,
E 00:00:01.395676 executorch:operator_registry.cpp:187] 1,
E 00:00:01.395677 executorch:operator_registry.cpp:187] 2,
E 00:00:01.395679 executorch:operator_registry.cpp:187] 3,
E 00:00:01.395682 executorch:operator_registry.cpp:187] ]
E 00:00:01.395683 executorch:operator_registry.cpp:187] dtype: 6 | dim order: [
E 00:00:01.395685 executorch:operator_registry.cpp:187] 0,
E 00:00:01.395688 executorch:operator_registry.cpp:187] 1,
E 00:00:01.395690 executorch:operator_registry.cpp:187] 2,
E 00:00:01.395691 executorch:operator_registry.cpp:187] 3,
E 00:00:01.395693 executorch:operator_registry.cpp:187] ]
E 00:00:01.395694 executorch:method.cpp:599] Missing operator: [26] aten::max_pool2d_with_indices_backward.grad_input
E 00:00:01.395706 executorch:method.cpp:816] There are 2 instructions don't have corresponding operator registered. See logs for details
E 00:00:01.395860 executorch:train.cpp:156] Failed to get named parameters

Is this what you mean by missing ops?

@holgerroth
Author

@JacobSzwejbka any updates on supporting the Conv2D operation for training? CC @YuanTingHsieh

@JacobSzwejbka
Contributor

Sorry, I missed this update. Yes, this is what I meant by a missing op.

@JacobSzwejbka
Contributor

@holgerroth CIFAR is easy enough that you could converge here without a max pool layer, right? Can you just drop it for now to unblock yourself while I work on getting it decomposed/implemented?

@holgerroth
Author

I managed to get it to work by using nn.Conv2d(..., stride=2) and skipping the max pooling layers. Still, it would be nice to support max pooling as well.
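
For anyone hitting the same thing, a rough sketch of that workaround (the exact layer sizes are my assumption, not necessarily what was used): replacing each Conv2d + MaxPool2d pair with a stride-2 Conv2d keeps the 5x5 feature map in front of fc1 and avoids the missing max_pool2d_with_indices_backward kernel entirely:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvNetNoPool(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5, stride=2)   # 32x32 -> 14x14
        self.conv2 = nn.Conv2d(6, 16, 5, stride=2)  # 14x14 -> 5x5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)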
