Hi Qwen team, thanks for releasing Qwen3Guard-Stream!
While integrating the model we noticed a performance issue in stream_generate and wanted to share our findings.
Bug Description
stream_generate in modeling_qwen3_guard.py passes the full growing token sequence on every call, manually rebuilding the causal mask each time. The past_key_values returned by forward() are stored but never actually reused — the next call always starts from scratch with the full sequence.
As a result, every step feeds all N tokens seen so far through the model instead of just the one new token, so per-token cost grows with N and latency rises with conversation length.
Root Cause
In the original implementation:
# Every iteration:
current_input_ids = torch.cat([current_input_ids, torch.tensor([next_token_id], device=self.device)])
# ... rebuild causal_mask manually ...
outputs = self.forward(
    input_ids=current_input_ids.unsqueeze(0),  # full sequence every time
    attention_mask=causal_mask,
    past_key_values=past_key_values,  # stored but ignored on next call
)
The full current_input_ids (length N) is passed on every step. past_key_values is returned but effectively thrown away, because the next call reprocesses the entire history. Additionally, use_cache=True is not passed to forward(), so the KV cache is not actually populated.
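To make the cost difference concrete, here is a minimal counting sketch. It does not touch the model; it only counts how many token positions are fed through forward() over one stream, assuming one call for the prompt and one call per streamed token. The function name and the example lengths (16 and 100) are ours, chosen for illustration.

```python
def tokens_processed(prompt_len, new_tokens, use_cache):
    """Count token positions fed through the model across one stream.

    Assumes one forward call for the prompt, then one call per new token.
    """
    total = prompt_len      # the first call always sees the whole prompt
    seq_len = prompt_len
    for _ in range(new_tokens):
        seq_len += 1
        # Without a cache the whole sequence is re-fed; with one, only the new token.
        total += 1 if use_cache else seq_len
    return total


print(tokens_processed(16, 100, use_cache=False))  # 6666
print(tokens_processed(16, 100, use_cache=True))   # 116
```

The uncached total grows quadratically with the number of streamed tokens, while the cached total grows linearly.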
Fix
Two changes are required together.
1. stream_generate — use KV cache correctly
On the first call, process the full initial sequence and populate the KV cache; on each subsequent call, pass only the single new token:
@torch.no_grad()
def stream_generate(self, input_ids):
    # First call: full initial sequence, populate KV cache
    outputs = self.forward(
        input_ids=input_ids.unsqueeze(0),
        use_cache=True,
        logits_to_keep=1,
    )
    past_key_values = outputs.past_key_values
    next_token_id = yield (
        outputs.risk_level_logits, outputs.category_logits,
        outputs.query_risk_level_logits, outputs.query_category_logits,
    )
    while next_token_id is not None:
        # Subsequent calls: single new token only
        outputs = self.forward(
            input_ids=next_token_id.reshape(1, 1),
            use_cache=True,
            past_key_values=past_key_values,
            logits_to_keep=1,
        )
        past_key_values = outputs.past_key_values
        next_token_id = yield (
            outputs.risk_level_logits, outputs.category_logits,
            outputs.query_risk_level_logits, outputs.query_category_logits,
        )
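For readers less familiar with the yield/send pattern used above, the following is a toy stand-in, not the model code. It assumes the caller primes the generator once with next() and then feeds one token per step with send(). The name toy_stream and the token ids are hypothetical, and each yielded value is simply the number of tokens that step had to process.

```python
def toy_stream(initial_ids):
    """Stand-in generator with the same yield/send shape as stream_generate."""
    cache = list(initial_ids)          # plays the role of past_key_values
    next_id = yield len(initial_ids)   # first step processed the full prompt
    while next_id is not None:
        cache.append(next_id)          # the cache grows by one entry per step
        next_id = yield 1              # later steps process one token only


stream = toy_stream([101, 102, 103])
work = [next(stream)]                  # prime: runs up to the first yield
for token_id in [7, 8, 9]:
    work.append(stream.send(token_id)) # feed exactly one new token per step
stream.close()
print(work)  # [3, 1, 1, 1]
```

Only the first step touches the whole prompt; every later step handles a single token, which is the behavior the fix is meant to restore.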
2. stream_moderate_from_ids — fix result extraction
Using logits_to_keep=1 makes every call return logits of shape (1, 1, num_classes) instead of (1, N, num_classes). After .squeeze(1) and torch.max(dim=-1) the predictions have shape (1,), so pred_risk_idx[0] is a 0-d scalar tensor rather than a 1-D tensor, and the existing list comprehensions fail with TypeError: iteration over a 0-d tensor.
Replace the list comprehensions with direct scalar indexing:
# user role
result = {
    "risk_level": [self.query_risk_level_map[int(pred_risk_idx[0])]],
    "risk_prob": [round(float(pred_risk_prob[0]), 2)],
    "category": [self.query_category_map[int(pred_cat_idx[0])]],
    "category_prob": [round(float(pred_cat_prob[0]), 2)],
}
# assistant role: same pattern with response_* maps
result = {
    "risk_level": [self.response_risk_level_map[int(pred_risk_idx[0])]],
    "risk_prob": [round(float(pred_risk_prob[0]), 2)],
    "category": [self.response_category_map[int(pred_cat_idx[0])]],
    "category_prob": [round(float(pred_cat_prob[0]), 2)],
}
Note: logits_to_keep=1 is not just an optimisation here. It is required for the result extraction fix to work on the user's first call, which processes the full initial sequence: without it, pred_risk_idx[0] has shape (N,) and int() fails.
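The shape reasoning in this section can be traced without loading the model. The sketch below represents tensor shapes as plain tuples and mimics the three operations involved (.squeeze(1), torch.max(dim=-1), and indexing with [0]). The helper names and num_classes=4 are placeholders of ours, not values taken from the model.

```python
def squeeze(shape, dim):
    """Mimic Tensor.squeeze(dim): drop `dim` only when its size is 1."""
    if shape[dim] == 1:
        return shape[:dim] + shape[dim + 1:]
    return shape


def max_last(shape):
    """Mimic torch.max(dim=-1): the last dimension is reduced away."""
    return shape[:-1]


def index_first(shape):
    """Mimic t[0]: the leading dimension is dropped."""
    return shape[1:]


def extracted_shape(seq_len, num_classes=4):
    """Shape of pred_risk_idx[0] for logits of shape (1, seq_len, num_classes)."""
    logits = (1, seq_len, num_classes)
    return index_first(max_last(squeeze(logits, 1)))


print(extracted_shape(seq_len=1))  # ()    -> 0-d: int() works, iteration fails
print(extracted_shape(seq_len=7))  # (7,)  -> 1-D: iteration works, int() fails
```

Keeping a single position on every call means the extraction code only ever sees the 0-d case, which is what the scalar indexing above expects.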
Benchmark Results
Measured on RTX 4090, Qwen3Guard-Stream-0.6B, conversation of 15 tokens (1 user + 13 assistant tokens), 20 runs:
| Configuration | Avg ms/token |
| --- | --- |
| Without fix | ~25 ms |
| With fix | ~19 ms |
The growing-sequence behavior is clearly visible in the per-token timing without the fix: latency increases token by token as the sequence grows. In longer conversations the degradation would be significantly more pronounced.
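For anyone who wants to reproduce the per-token trend, a timing harness along these lines may help. This is a generic sketch rather than the script behind the numbers above: per_step_latency_ms and uncached_step are hypothetical names, and the stand-in step only imitates work that grows with history length.

```python
import time


def per_step_latency_ms(step, inputs):
    """Time each call to `step` separately and return the latencies in ms."""
    latencies = []
    for item in inputs:
        start = time.perf_counter()
        step(item)
        # For GPU work, call torch.cuda.synchronize() here before reading the
        # clock, since CUDA kernels are launched asynchronously.
        latencies.append((time.perf_counter() - start) * 1000.0)
    return latencies


# Hypothetical stand-in for one moderation step: work grows with history length.
history = []


def uncached_step(token_id):
    history.append(token_id)
    return sum(history)  # touches every past token, like re-feeding the sequence


latencies = per_step_latency_ms(uncached_step, range(50))
print(len(latencies))  # 50
```

Plotting the returned list against the step index should show a rising line without the fix and a roughly flat one with it.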