Skip to content

fix: squeeze 3D action tensor in LinUCB learn_batch#129

Open
dashitongzhi wants to merge 1 commit into
facebookresearch:mainfrom
dashitongzhi:fix/torchscript-bandit
Open

fix: squeeze 3D action tensor in LinUCB learn_batch#129
dashitongzhi wants to merge 1 commit into
facebookresearch:mainfrom
dashitongzhi:fix/torchscript-bandit

Conversation

@dashitongzhi
Copy link
Copy Markdown

Problem

When running the contextual_bandits_tutorial with LinUCB, training fails due to tensor dimension mismatch:

  • batch.state shape: [1, 16]
  • batch.action shape: [1, 1, 10] (one-hot encoded, 3D)

torch.cat([batch.state, batch.action], dim=1) fails because state is 2D but action is 3D.

Fix

Added torch.squeeze(batch.action, dim=1) to handle 3D action tensors. This safely converts [B, 1, N][B, N] while leaving already-2D [B, N] tensors unchanged (squeeze is a no-op when dim=1 has size > 1).

Fixes #125

batch.action can have shape [B, 1, N] for one-hot encoded actions,
but torch.cat with batch.state (shape [B, D]) requires 2D tensors.
Squeeze dim=1 to handle both [B, N] and [B, 1, N] action shapes.

Fixes facebookresearch#125
Copilot AI review requested due to automatic review settings May 8, 2026 15:51
@meta-cla
Copy link
Copy Markdown

meta-cla Bot commented May 8, 2026

Hi @dashitongzhi!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR addresses a shape mismatch in LinearBandit.learn_batch() (LinUCB) when TransitionBatch.action is provided as a 3D tensor (e.g., [B, 1, N] one-hot), which breaks feature concatenation with 2D batch.state.

Changes:

  • Squeezes batch.action on dimension 1 before concatenating with batch.state in LinearBandit.learn_batch().

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

else torch.ones_like(expected_values)
)
x = torch.cat([batch.state, batch.action], dim=1)
x = torch.cat([batch.state, torch.squeeze(batch.action, dim=1)], dim=1)
@dashitongzhi dashitongzhi reopened this May 9, 2026
@meta-cla
Copy link
Copy Markdown

meta-cla Bot commented May 9, 2026

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 9, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Contextual Bandits LInUCB errors during training due to mismatch action and state tensor size

2 participants