Skip to content

Commit 77993db

Browse files
committed
Fix half_to_float handling
1 parent ebaa53f commit 77993db

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ def _populate_trt_builder_config(
300300
if tactic_sources is not None:
301301
builder_config.set_tactic_sources(tactic_sources=tactic_sources)
302302

303+
builder_config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS)
304+
303305
return builder_config
304306

305307
def _create_timing_cache(

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,12 +431,13 @@ def softmax(
431431
) -> Union[TRTTensor, Sequence[TRTTensor]]:
432432
dim = get_positive_dim(dim, len(input.shape))
433433

434-
if half_to_float:
435-
input = cast_trt_tensor(ctx, input, torch.float, name, target, source_ir)
436-
437434
layer = ctx.net.add_softmax(input)
438435
layer.axes = 1 << dim
439436
set_layer_name(layer, target, name, source_ir)
437+
438+
if half_to_float:
439+
layer.precision = trt.DataType.FLOAT
440+
440441
return layer.get_output(0)
441442

442443

0 commit comments

Comments
 (0)