diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt index f52f4d38afd..47e1240a091 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt @@ -471,13 +471,15 @@ class HttpBindingGenerator( } private fun RustWriter.renderHeaders(httpBinding: HttpBinding) { + check(httpBinding.location == HttpLocation.HEADER) val memberShape = httpBinding.member - val memberType = model.expectShape(memberShape.target) + val targetShape = model.expectShape(memberShape.target) val memberSymbol = symbolProvider.toSymbol(memberShape) val memberName = symbolProvider.toMemberName(memberShape) - ifSet(memberType, memberSymbol, "&input.$memberName") { field -> - val isListHeader = memberType is CollectionShape - listForEach(memberType, field) { innerField, targetId -> + + ifSet(targetShape, memberSymbol, "&input.$memberName") { field -> + val isListHeader = targetShape is CollectionShape + listForEach(targetShape, field) { innerField, targetId -> val innerMemberType = model.expectShape(targetId) if (innerMemberType.isPrimitive()) { val encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder") @@ -491,9 +493,14 @@ class HttpBindingGenerator( """ let header_value = $safeName; let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| { - #{build_error}::InvalidField { field: "$memberName", details: format!("`{}` cannot be used as a header value: {}", &${ - memberShape.redactIfNecessary(model, "header_value") - }, err)} + #{build_error}::InvalidField { + field: "$memberName", + details: format!( + "`{}` cannot be used as a header value: {}", + &${memberShape.redactIfNecessary(model, "header_value")}, + err, + ) + } })?; builder = builder.header("${httpBinding.locationName}", header_value); """, @@ -505,17 +512,14 @@ class HttpBindingGenerator( } private fun RustWriter.renderPrefixHeader(httpBinding: HttpBinding) { + check(httpBinding.location == HttpLocation.PREFIX_HEADERS) val memberShape = httpBinding.member - val memberType = model.expectShape(memberShape.target) + val targetShape = model.expectShape(memberShape.target, MapShape::class.java) val memberSymbol = symbolProvider.toSymbol(memberShape) val memberName = symbolProvider.toMemberName(memberShape) - val target = when (memberType) { - is CollectionShape -> model.expectShape(memberType.member.target) - is MapShape -> model.expectShape(memberType.value.target) - else -> UNREACHABLE("unexpected member for prefix headers: $memberType") - } - ifSet(memberType, memberSymbol, "&input.$memberName") { field -> - val listHeader = memberType is CollectionShape + val valueTargetShape = model.expectShape(targetShape.value.target) + + ifSet(targetShape, memberSymbol, "&input.$memberName") { field -> rustTemplate( """ for (k, v) in $field { @@ -523,17 +527,19 @@ class HttpBindingGenerator( let header_name = http::header::HeaderName::from_str(&format!("{}{}", "${httpBinding.locationName}", &k)).map_err(|err| { #{build_error}::InvalidField { field: "$memberName", details: format!("`{}` cannot be used as a header name: {}", k, err)} })?; - let header_value = ${headerFmtFun(this, target, memberShape, "v", listHeader)}; + let header_value = ${headerFmtFun(this, valueTargetShape, memberShape, "v", isListHeader = false)}; let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| { #{build_error}::InvalidField { field: "$memberName", - details: format!("`{}` cannot be used as a header value: {}", ${ - memberShape.redactIfNecessary(model, "v") - }, err)} + details: format!( + "`{}` cannot be used as a header value: {}", + ${memberShape.redactIfNecessary(model, "v")}, + err, + ) + } })?; builder = builder.header(header_name, header_value); } - """, "build_error" to runtimeConfig.operationBuildError(), )