 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import subprocess
 import tempfile
 import textwrap
 
-# TORCH_LOGS=+dtensor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
 from transformers import is_torch_available
-from transformers.models.llama.configuration_llama import LlamaConfig
-from transformers.models.llama.modeling_llama import LlamaModel
 from transformers.testing_utils import (
     TestCasePlus,
-    execute_subprocess_async,
     get_torch_dist_unique_port,
     require_torch_multi_gpu,
 )
     import torch
 
 
+# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
 class TestTensorParallel(TestCasePlus):
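+    # Number of processes launched by torchrun; the CUDA test class below overrides this with the GPU count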
+    nproc_per_node = 2
+
     def torchrun(self, script: str):
         """Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
         with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
             tmp.write(script)
             tmp.flush()
             tmp.seek(0)
             cmd = (
-                f"torchrun --nproc_per_node {torch.cuda.device_count()} --master_port {get_torch_dist_unique_port()} {tmp.name}"
+                f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
             ).split()
 
             # Note that the subprocess will be waited for here, and raise an error if not successful
@@ -50,44 +48,39 @@ def torchrun(self, script: str):
             except subprocess.CalledProcessError as e:
                 raise Exception(f"The following error was captured: {e.stderr}")
 
-    @require_torch_multi_gpu
-    def test_tp(self):
-        distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
-            --master_port={get_torch_dist_unique_port()}
-            {self.test_file_dir}/test_tp.py
-        """.split()
-        output_dir = self.get_auto_remove_tmp_dir()
-        args = f"--output_dir {output_dir} --report_to none".split()
-        cmd = ["torchrun"] + distributed_args + args
-        print(cmd)
-        execute_subprocess_async(cmd, env=self.get_env())
-        # successful return here == success - any errors would have caused an error in the sub-call
-
-    @require_torch_multi_gpu
-    def test_loading_memory_consumption(self):
+    def test_model_forward(self):
         script_to_run = textwrap.dedent(
             """
             import torch
             import os
-            from transformers import AutoModelForCausalLM
+            from transformers import AutoModelForCausalLM, AutoTokenizer
 
-            model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+            model_id = "JackFram/llama-68m"
 
             rank = int(os.environ["RANK"])
             world_size = int(os.environ["WORLD_SIZE"])
-            device = torch.device(f"cuda:{rank}")
-            torch.distributed.init_process_group("nccl", device_id=device)
 
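+            # tp_plan="auto" loads the checkpoint sharded across the torchrun ranks (tensor parallelism)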
-            model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")
+            model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
             torch.distributed.barrier()
 
-            # The expected model memory footprint. We add 1 as not all the modules are split (e.g. the embeddings)
-            expected_model_memory_per_device = (16 / world_size) + 1
-            overhead_factor = 1.2
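+            # At least some parameters should have been sharded as DTensors by the TP plan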
+            has_dtensor = 0
+            for name, parameter in model.named_parameters():
+                if isinstance(parameter.data, torch.distributed.tensor.DTensor):
+                    has_dtensor = 1
+                    break
+
+            assert has_dtensor == 1, "TP model must have DTensor"
+
+            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            prompt = "Can I help"
+
+            inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
+            outputs = model(inputs)
 
-            # Check that we do not use more than the expected sharded size during initialization
-            if torch.cuda.max_memory_allocated(device) / 1024**3 > expected_model_memory_per_device * overhead_factor:
-                raise ValueError("Loading the model used more than the expected fraction of model size per device")
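+            # Greedy-pick the next token from the last position's logits and check the expected continuation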
+            next_token_logits = outputs[0][:, -1, :]
+            next_token = torch.argmax(next_token_logits, dim=-1)
+            response = tokenizer.decode(next_token)
+            assert response == "with"
 
             torch.distributed.barrier()
             torch.distributed.destroy_process_group()
@@ -96,69 +89,6 @@ def test_loading_memory_consumption(self):
         self.torchrun(script_to_run)
 
 
-if __name__ == "__main__":
-    # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
-    # CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tp/test_tp.py
-    # or
-    # PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
-
-    if not is_torch_available():
-        exit(0)
-
-    # Test settings
-    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
-    bs = 1
-    seqlen = 4096
-    # Get distributed settings
-    rank = int(os.environ["RANK"])
-    world_size = int(os.environ["WORLD_SIZE"])
-
-    # Initialize distributed
-    device = torch.device(f"cuda:{rank}")
-    torch.distributed.init_process_group("nccl", device_id=device)
-    device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,))
-
-    # Get model config
-    config = LlamaConfig.from_pretrained(model_id)
-    config.hidden_size = 2048
-    config.attention_bias = False
-    # Instantiate model
-    with device:
-        model = LlamaModel(config).to(dtype=torch.float16)
-
-    model.eval()
-    # Tensor Parallel
-    if world_size > 1:
-        model.tensor_parallel(device_mesh)
-    # Run model
-
-    inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
-
-    # Test cuda graphing explicitly
-    with torch.cuda.device(device):
-        print("Cuda graphing")
-        with torch.no_grad():
-            inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
-            # CUDA Graph setup
-            s = torch.cuda.Stream(device=device)
-            s.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(s):
-                for i in range(3):
-                    out = model(inputs)
-            torch.cuda.current_stream().wait_stream(s)
-            g = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(g):
-                out = model(inputs)
-
-            for _ in range(2):
-                g.replay()
-            s.synchronize()
-
-            assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])
-
-    # Test compile
-    with torch.no_grad():
-        out = model(inputs)
-        model.forward = torch.compile(model.forward, mode="reduce-overhead")
-        out = model(inputs)
-        out = model(inputs)
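+# Re-run the same tensor-parallel tests on GPU, with one torchrun process per visible CUDA device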
+@require_torch_multi_gpu
+class TestTensorParallelCuda(TestTensorParallel):
+    nproc_per_node = torch.cuda.device_count()