Skip to content

Commit e70681f

Browse files
committed
Auto merge of #12992 - lowr:fix/type-inference-for-byte-string-pat, r=Veykril
fix: infer byte string pattern as `&[u8]` when matched against slices Fixes #12630 c.f. [rustc_typeck](https://github.com/rust-lang/rust/blob/1603a70f82240ba2d27f72f964e36614d7620ad3/compiler/rustc_typeck/src/check/pat.rs#L388-L404)
2 parents d79d9e1 + ffc6b42 commit e70681f

File tree

2 files changed

+74
-4
lines changed

2 files changed

+74
-4
lines changed

crates/hir-ty/src/infer/pat.rs

+29-4
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ use crate::{
1414
consteval::intern_const_scalar,
1515
infer::{BindingMode, Expectation, InferenceContext, TypeMismatch},
1616
lower::lower_to_chalk_mutability,
17-
static_lifetime, ConcreteConst, ConstValue, Interner, Substitution, Ty, TyBuilder, TyExt,
18-
TyKind,
17+
primitive::UintTy,
18+
static_lifetime, ConcreteConst, ConstValue, Interner, Scalar, Substitution, Ty, TyBuilder,
19+
TyExt, TyKind,
1920
};
2021

2122
use super::PatLike;
@@ -294,7 +295,29 @@ impl<'a> InferenceContext<'a> {
294295
let start_ty = self.infer_expr(*start, &Expectation::has_type(expected.clone()));
295296
self.infer_expr(*end, &Expectation::has_type(start_ty))
296297
}
297-
Pat::Lit(expr) => self.infer_expr(*expr, &Expectation::has_type(expected.clone())),
298+
&Pat::Lit(expr) => {
299+
// FIXME: using `Option` here is a workaround until we can use if-let chains in stable.
300+
let mut pat_ty = None;
301+
302+
// Like slice patterns, byte string patterns can denote both `&[u8; N]` and `&[u8]`.
303+
if let Expr::Literal(Literal::ByteString(_)) = self.body[expr] {
304+
if let Some((inner, ..)) = expected.as_reference() {
305+
let inner = self.resolve_ty_shallow(inner);
306+
if matches!(inner.kind(Interner), TyKind::Slice(_)) {
307+
let elem_ty = TyKind::Scalar(Scalar::Uint(UintTy::U8)).intern(Interner);
308+
let slice_ty = TyKind::Slice(elem_ty).intern(Interner);
309+
let ty = TyKind::Ref(Mutability::Not, static_lifetime(), slice_ty)
310+
.intern(Interner);
311+
self.write_expr_ty(expr, ty.clone());
312+
pat_ty = Some(ty);
313+
}
314+
}
315+
}
316+
317+
pat_ty.unwrap_or_else(|| {
318+
self.infer_expr(expr, &Expectation::has_type(expected.clone()))
319+
})
320+
}
298321
Pat::Box { inner } => match self.resolve_boxed_box() {
299322
Some(box_adt) => {
300323
let (inner_ty, alloc_ty) = match expected.as_adt() {
@@ -343,7 +366,9 @@ fn is_non_ref_pat(body: &hir_def::body::Body, pat: PatId) -> bool {
343366
// FIXME: ConstBlock/Path/Lit might actually evaluate to ref, but inference is unimplemented.
344367
Pat::Path(..) => true,
345368
Pat::ConstBlock(..) => true,
346-
Pat::Lit(expr) => !matches!(body[*expr], Expr::Literal(Literal::String(..))),
369+
Pat::Lit(expr) => {
370+
!matches!(body[*expr], Expr::Literal(Literal::String(..) | Literal::ByteString(..)))
371+
}
347372
Pat::Bind {
348373
mode: BindingAnnotation::Mutable | BindingAnnotation::Unannotated,
349374
subpat: Some(subpat),

crates/hir-ty/src/tests/patterns.rs

+45
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,51 @@ fn infer_pattern_match_string_literal() {
315315
);
316316
}
317317

318+
#[test]
319+
fn infer_pattern_match_byte_string_literal() {
320+
check_infer_with_mismatches(
321+
r#"
322+
//- minicore: index
323+
struct S;
324+
impl<T, const N: usize> core::ops::Index<S> for [T; N] {
325+
type Output = [u8];
326+
fn index(&self, index: core::ops::RangeFull) -> &Self::Output {
327+
loop {}
328+
}
329+
}
330+
fn test(v: [u8; 3]) {
331+
if let b"foo" = &v[S] {}
332+
if let b"foo" = &v {}
333+
}
334+
"#,
335+
expect![[r#"
336+
105..109 'self': &[T; N]
337+
111..116 'index': {unknown}
338+
157..180 '{ ... }': &[u8]
339+
167..174 'loop {}': !
340+
172..174 '{}': ()
341+
191..192 'v': [u8; 3]
342+
203..261 '{ ...v {} }': ()
343+
209..233 'if let...[S] {}': ()
344+
212..230 'let b"... &v[S]': bool
345+
216..222 'b"foo"': &[u8]
346+
216..222 'b"foo"': &[u8]
347+
225..230 '&v[S]': &[u8]
348+
226..227 'v': [u8; 3]
349+
226..230 'v[S]': [u8]
350+
228..229 'S': S
351+
231..233 '{}': ()
352+
238..259 'if let... &v {}': ()
353+
241..256 'let b"foo" = &v': bool
354+
245..251 'b"foo"': &[u8; 3]
355+
245..251 'b"foo"': &[u8; 3]
356+
254..256 '&v': &[u8; 3]
357+
255..256 'v': [u8; 3]
358+
257..259 '{}': ()
359+
"#]],
360+
);
361+
}
362+
318363
#[test]
319364
fn infer_pattern_match_or() {
320365
check_infer_with_mismatches(

0 commit comments

Comments
 (0)