Commit 6b74b00

Reformatted tensor parallelism
1 parent dfc65d0 commit 6b74b00

File tree

3 files changed: +9 -23 lines changed

3 files changed

+9
-23
lines changed

examples/distributed_inference/llama3_model.py

Lines changed: 5 additions & 5 deletions
@@ -3,7 +3,7 @@
 
 
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -195,7 +195,7 @@ def __init__(self, model_args: ModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )
 
-    def init_weights(self, init_std: float):
+    def init_weights(self, init_std: float) -> None:
         for linear in (self.wq, self.wk, self.wv):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
@@ -204,7 +204,7 @@ def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
-    ):
+    ) -> Any:
         """Forward pass of the attention module.
 
         Args:
@@ -275,10 +275,10 @@ def __init__(
         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
 
-    def forward(self, x):
+    def forward(self, x) -> Any:
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
-    def init_weights(self, init_std: float):
+    def init_weights(self, init_std: float) -> None:
         nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
         for linear in (self.w2, self.w3):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
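The annotations above land on a SwiGLU-style feed-forward block. For reference, a minimal self-contained sketch of how that block reads after this commit is below; the w1 definition, the dim/hidden_dim values, and the standalone usage at the end are filled in for illustration and are not taken verbatim from the example file.

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        # w1's shape is inferred from forward(); it does not appear in the hunks above.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x) -> Any:
        # SwiGLU-style gating, as in the diff above
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

    def init_weights(self, init_std: float) -> None:
        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
        for linear in (self.w2, self.w3):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


# Illustrative usage with made-up sizes.
ffn = FeedForward(dim=64, hidden_dim=256)
ffn.init_weights(init_std=0.02)
out = ffn(torch.randn(2, 8, 64))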

examples/distributed_inference/tensor_parallel_llama3.py

Lines changed: 4 additions & 7 deletions
@@ -1,3 +1,6 @@
+# Taken and modified pytorch lightening
+# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
+import logging
 import os
 import time
 
@@ -12,9 +15,6 @@
 )
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 
-# Taken and modified pytorch lightening
-# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
-import logging
 _rank = int(os.environ["RANK"])
 _world_size = int(os.environ["WORLD_SIZE"])
 tp_size = 2
@@ -25,9 +25,6 @@
 fh.setLevel(logging.INFO)
 logger.addHandler(fh)
 
-# understand world topology
-
-
 tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
 
 model_args = ModelArgs(
@@ -56,7 +53,7 @@
         "use_python_runtime": True,
         "workspace_size": 1 << 33,
         "debug": False,
-        "timing_cache_path":"/opt/file/cache/timing_cache_llama.bin"
+        "timing_cache_path": "/opt/file/cache/timing_cache_llama.bin",
     },
     dynamic=False,
 )
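The last hunk only reformats one entry of a compile-settings dict; judging from the dynamic=False argument in the surrounding context, it belongs to a torch.compile call using the torch_tensorrt backend. A rough, hedged sketch of such a call is below; the tiny placeholder module and input stand in for the tensor-parallel Llama3 model the real script builds on each rank.

import torch
import torch.nn as nn
import torch_tensorrt  # importing registers the "torch_tensorrt" torch.compile backend

# Placeholder module and input; the example compiles the sharded Llama3 model instead.
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU()).eval().cuda()
example_input = torch.randn(8, 32, device="cuda")

compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={
        "use_python_runtime": True,
        "workspace_size": 1 << 33,
        "debug": False,
        "timing_cache_path": "/opt/file/cache/timing_cache_llama.bin",
    },
    dynamic=False,
)
out = compiled(example_input)  # the first call triggers the TensorRT engine build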

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 0 additions & 11 deletions
@@ -362,21 +362,10 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     trt_modules = {}
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
-    logger.info(f"-" * 100)
-    logger.info(f"There are {len(list(partitioned_module.named_children()))} submodules in total.")
-    i = 0
-    import os
     for name, _ in partitioned_module.named_children():
-        # Benchmark log utilities
-        i += 1
-        logger.info(f"-" * 100)
-        logger.info(f"Start compiling {i}th submodule")
-        total = torch.cuda.get_device_properties(0).total_memory
-
         submodule = getattr(partitioned_module, name)
         # Criteria for a module to be convertible to TRT
         if settings.use_fast_partitioner and "_run_on_acc" not in name:
-            # if (settings.use_fast_partitioner and "_run_on_acc" not in name) or int(os.environ["RANK"]) == 1:
             dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule))
             continue
 
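The deletions above strip ad-hoc benchmark instrumentation out of the partition-and-convert loop. If per-submodule progress is still wanted, a less invasive option is to raise the level of the compiler's standard loggers from the calling script instead of patching _compiler.py. The sketch below assumes the usual logging.getLogger(__name__) naming convention for the library's loggers, and the log-file name is illustrative.

import logging
import os

# One log file per rank, mirroring the FileHandler pattern in tensor_parallel_llama3.py.
rank = int(os.environ.get("RANK", "0"))
handler = logging.FileHandler(f"rank_{rank}_compile.log")
handler.setLevel(logging.DEBUG)

# Raise the dynamo compiler's logger level to surface its per-submodule messages.
compiler_logger = logging.getLogger("torch_tensorrt.dynamo")
compiler_logger.setLevel(logging.DEBUG)
compiler_logger.addHandler(handler)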

0 commit comments
