Bug: stream_generate passes full growing sequence on every call instead of using KV cache #22

Description

@alexsarrell

Hi Qwen team, thanks for releasing Qwen3Guard-Stream!

While integrating the model we noticed a performance issue in stream_generate and wanted to share our findings.

Bug Description

stream_generate in modeling_qwen3_guard.py passes the full growing token sequence on every call, manually rebuilding the causal mask each time. The past_key_values returned by forward() are stored but never actually reused — the next call always starts from scratch with the full sequence.

This makes the number of tokens processed per step O(N) instead of O(1), so latency grows with conversation length.
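To make the scaling concrete, here is a rough count of how many token positions the model processes over one stream. This is only a sketch: the function names are made up for illustration, and it counts forward-pass positions, not wall-clock time.

```python
def positions_without_cache(prompt_len: int, new_tokens: int) -> int:
    """Every step re-runs the whole sequence seen so far."""
    # Step 0 processes the prompt; step k processes prompt_len + k positions.
    return sum(prompt_len + k for k in range(new_tokens + 1))


def positions_with_cache(prompt_len: int, new_tokens: int) -> int:
    """The prompt is processed once, then one position per new token."""
    return prompt_len + new_tokens


if __name__ == "__main__":
    for n in (13, 100, 1000):
        print(n, positions_without_cache(2, n), positions_with_cache(2, n))
```

Even with a cache, attention for the new token still reads every cached key, so a step is not strictly constant-time. The saving is that all per-position layers run on one token instead of N.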

Root Cause

In the original implementation:

# Every iteration:
current_input_ids = torch.cat([current_input_ids, torch.tensor([next_token_id], device=self.device)])
# ... rebuild causal_mask manually ...
outputs = self.forward(
    input_ids=current_input_ids.unsqueeze(0),  # full sequence every time
    attention_mask=causal_mask,
    past_key_values=past_key_values             # stored but ignored on next call
)

The full current_input_ids (length N) is passed on every step. past_key_values is returned but effectively thrown away, because the next call reprocesses the entire history. In addition, use_cache=True is not passed to forward(), so the KV cache is never actually populated.
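For readers less familiar with KV caching, the toy below shows why reusing cached keys and values is safe: the output for the newest position matches a full causal recomputation. It is a sketch in plain Python, not the model's attention code. It uses a single head, and each token vector stands in for its own query, key and value (identity projections), which is a simplifying assumption.

```python
import math
import random


def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(dim)]


def full_recompute(tokens):
    """Recompute attention at every position and return the last output."""
    outputs = []
    for t in range(len(tokens)):
        # Causal mask: position t attends to positions 0..t only.
        outputs.append(attend(tokens[t], tokens[: t + 1], tokens[: t + 1]))
    return outputs[-1]


class CachedAttention:
    """Keeps keys/values from earlier steps so each step handles one token."""

    def __init__(self):
        self.keys = []
        self.values = []

    def step(self, token):
        self.keys.append(token)
        self.values.append(token)
        return attend(token, self.keys, self.values)


if __name__ == "__main__":
    random.seed(0)
    tokens = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(6)]
    cache = CachedAttention()
    for t, token in enumerate(tokens):
        cached = cache.step(token)
        reference = full_recompute(tokens[: t + 1])
        print(t, all(math.isclose(a, b) for a, b in zip(cached, reference)))
```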

Fix

Two changes are required together.

1. stream_generate — use KV cache correctly

On the first call, process the full initial sequence and populate the KV cache; on each subsequent call, pass only the single new token:

@torch.no_grad()
def stream_generate(self, input_ids):
    # First call: full initial sequence, populate KV cache
    outputs = self.forward(
        input_ids=input_ids.unsqueeze(0),
        use_cache=True,
        logits_to_keep=1,
    )
    past_key_values = outputs.past_key_values
    next_token_id = yield (
        outputs.risk_level_logits, outputs.category_logits,
        outputs.query_risk_level_logits, outputs.query_category_logits,
    )
    while next_token_id is not None:
        # Subsequent calls: single new token only (accepts an int or a tensor)
        outputs = self.forward(
            input_ids=torch.as_tensor(next_token_id, device=self.device).reshape(1, 1),
            use_cache=True,
            past_key_values=past_key_values,
            logits_to_keep=1,
        )
        past_key_values = outputs.past_key_values
        next_token_id = yield (
            outputs.risk_level_logits, outputs.category_logits,
            outputs.query_risk_level_logits, outputs.query_category_logits,
        )
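Since stream_generate is a generator that receives tokens through send(), the calling pattern may be worth spelling out. The sketch below uses a stand-in generator with the same yield/send shape. `stream_generate_stub` and `run_stream` are hypothetical names, and the tuples it yields are placeholders rather than real logits.

```python
def stream_generate_stub(initial_ids):
    """Stand-in with the same protocol: yields (phase, positions processed, cache length)."""
    cache_len = len(initial_ids)  # pretend the prompt is now in the KV cache
    next_token_id = yield ("prefill", len(initial_ids), cache_len)
    while next_token_id is not None:
        cache_len += 1  # one new position appended to the cache
        next_token_id = yield ("decode", 1, cache_len)


def run_stream(initial_ids, streamed_ids):
    gen = stream_generate_stub(initial_ids)
    results = [next(gen)]  # first call runs the prefill up to the first yield
    for token_id in streamed_ids:
        # send() resumes the generator with the new token and returns the next result
        results.append(gen.send(token_id))
    gen.close()  # finish the generator without sending None
    return results


if __name__ == "__main__":
    for item in run_stream([101, 102, 103], [7, 8]):
        print(item)
```

Calling close() avoids the StopIteration that gen.send(None) would raise in the caller once the while loop exits.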

2. stream_moderate_from_ids — fix result extraction

Using logits_to_keep=1 makes every call return logits of shape (1, 1, num_classes) instead of (1, N, num_classes). After .squeeze(1) and torch.max(dim=-1) the result has shape (1,), so pred_risk_idx[0] is a 0-d scalar tensor, not a 1-D tensor, and the existing list comprehensions fail with a TypeError because a 0-d tensor cannot be iterated.

Replace the list comprehensions with direct scalar indexing:

# user role
result = {
    "risk_level":    [self.query_risk_level_map[int(pred_risk_idx[0])]],
    "risk_prob":     [round(float(pred_risk_prob[0]), 2)],
    "category":      [self.query_category_map[int(pred_cat_idx[0])]],
    "category_prob": [round(float(pred_cat_prob[0]), 2)],
}
# assistant role — same pattern with response_* maps
result = {
    "risk_level":    [self.response_risk_level_map[int(pred_risk_idx[0])]],
    "risk_prob":     [round(float(pred_risk_prob[0]), 2)],
    "category":      [self.response_category_map[int(pred_cat_idx[0])]],
    "category_prob": [round(float(pred_cat_prob[0]), 2)],
}

Note: logits_to_keep=1 is not just an optimisation here. It is required for the result extraction fix to work on the first call for the user role, which otherwise processes N positions and produces pred_risk_idx[0] of shape (N,), on which int() fails.
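The shape reasoning above can be checked in isolation with random logits. This is a sketch under assumptions: `predict` is a hypothetical helper, and the softmax step is assumed rather than taken from modeling_qwen3_guard.py. Only the squeeze and max steps are described in this issue.

```python
import torch

num_classes = 3


def predict(logits):
    """Mirror the squeeze / max steps on a (batch, positions, classes) tensor."""
    probs = torch.softmax(logits.squeeze(1), dim=-1)  # softmax is an assumption
    pred_prob, pred_idx = torch.max(probs, dim=-1)
    return pred_prob, pred_idx


# With logits_to_keep=1 there is one position per call.
prob_one, idx_one = predict(torch.randn(1, 1, num_classes))
print(idx_one.shape, idx_one[0].dim())  # batch axis only; idx_one[0] is 0-d
label_index = int(idx_one[0])           # a single element converts cleanly
label_prob = round(float(prob_one[0]), 2)

# Without it, a first call over N = 4 positions keeps the sequence axis.
_, idx_many = predict(torch.randn(1, 4, num_classes))
print(idx_many.shape, idx_many[0].shape)  # idx_many[0] has one entry per position
```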

Benchmark Results

Measured on an RTX 4090 with Qwen3Guard-Stream-0.6B, on a conversation of 15 tokens (1 user + 13 assistant tokens), over 20 runs:

| Configuration | Avg ms/token |
| --- | --- |
| Without fix | ~25 ms |
| With fix | ~19 ms |

The growing-sequence behavior is clearly visible in the per-token timing without the fix: latency increases token by token as the sequence grows. In longer conversations the degradation would be significantly more pronounced.
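For anyone who wants to reproduce the per-token trend, a minimal timing helper could look like the sketch below. It is not the script used for the numbers above. `per_token_latency_ms` and the `step` callback are hypothetical, and `step` is assumed to wrap one moderation call for one token.

```python
import time
from typing import Callable, List


def per_token_latency_ms(step: Callable[[int], object], token_ids: List[int]) -> List[float]:
    """Time each call to `step` separately and return latencies in milliseconds."""
    latencies = []
    for token_id in token_ids:
        start = time.perf_counter()
        step(token_id)
        # For a GPU model, call torch.cuda.synchronize() here before reading
        # the clock, otherwise only the kernel launch time is measured.
        latencies.append((time.perf_counter() - start) * 1000.0)
    return latencies


if __name__ == "__main__":
    history = []

    def growing_step(token_id: int) -> int:
        # Dummy workload whose cost grows with history length, like the unfixed path.
        history.append(token_id)
        return sum(x * x for x in history * 1000)

    print(per_token_latency_ms(growing_step, list(range(10))))
```

Plotting the returned list against token index should show a rising line for the growing-sequence path and a roughly flat one once only single tokens are passed.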
