1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
17-
1816import torch
1917
2018from diffusers import CogVideoXTransformer3DModel
21-
22- from ...testing_utils import (
23- enable_full_determinism ,
24- torch_device ,
19+ from diffusers .utils .torch_utils import randn_tensor
20+
21+ from ...testing_utils import enable_full_determinism , torch_device
22+ from ..testing_utils import (
23+ BaseModelTesterConfig ,
24+ ModelTesterMixin ,
25+ TorchCompileTesterMixin ,
26+ TrainingTesterMixin ,
2527)
26- from ..test_modeling_common import ModelTesterMixin
2728
2829
2930enable_full_determinism ()
3031
3132
32- class CogVideoXTransformerTests (ModelTesterMixin , unittest .TestCase ):
33- model_class = CogVideoXTransformer3DModel
34- main_input_name = "hidden_states"
35- uses_custom_attn_processor = True
36- model_split_percents = [0.7 , 0.7 , 0.8 ]
33+ # ======================== CogVideoX ========================
3734
35+
36+ class CogVideoXTransformerTesterConfig (BaseModelTesterConfig ):
    @property
    def model_class(self):
        # Model class exercised by this tester configuration.
        return CogVideoXTransformer3DModel
4740
48- hidden_states = torch . randn (( batch_size , num_frames , num_channels , height , width )). to ( torch_device )
49- encoder_hidden_states = torch . randn (( batch_size , sequence_length , embedding_dim )). to ( torch_device )
50- timestep = torch . randint ( 0 , 1000 , size = ( batch_size ,)). to ( torch_device )
    @property
    def main_input_name(self) -> str:
        # Name of the primary tensor argument in the model's forward signature.
        return "hidden_states"
5144
52- return {
53- "hidden_states" : hidden_states ,
54- "encoder_hidden_states" : encoder_hidden_states ,
55- "timestep" : timestep ,
56- }
    @property
    def model_split_percents(self) -> list:
        # Split fractions consumed by the shared device-placement tests.
        # NOTE(review): exact semantics are defined by the test mixins — confirm there.
        return [0.7, 0.7, 0.8]
5748
    @property
    def output_shape(self) -> tuple:
        # Expected per-sample output shape used by the shared shape checks.
        # NOTE(review): dimension meaning is defined by the test mixins — confirm.
        return (1, 4, 8, 8)
6152
    @property
    def input_shape(self) -> tuple:
        # Expected per-sample input shape used by the shared shape checks.
        # NOTE(review): dimension meaning is defined by the test mixins — confirm.
        return (1, 4, 8, 8)
6556
66- def prepare_init_args_and_inputs_for_common (self ):
67- init_dict = {
68- # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
    @property
    def generator(self):
        # A fresh CPU generator seeded to 0 on EVERY access, so each tensor
        # drawn with it is reproducible independently of access order.
        return torch.Generator("cpu").manual_seed(0)
60+
61+ def get_init_dict (self ) -> dict :
62+ return {
6963 "num_attention_heads" : 2 ,
7064 "attention_head_dim" : 8 ,
7165 "in_channels" : 4 ,
@@ -81,50 +75,66 @@ def prepare_init_args_and_inputs_for_common(self):
8175 "temporal_compression_ratio" : 4 ,
8276 "max_text_seq_length" : 8 ,
8377 }
84- inputs_dict = self .dummy_input
85- return init_dict , inputs_dict
8678
87- def test_gradient_checkpointing_is_applied (self ):
88- expected_set = {"CogVideoXTransformer3DModel" }
89- super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
90-
91-
92- class CogVideoX1_5TransformerTests (ModelTesterMixin , unittest .TestCase ):
93- model_class = CogVideoXTransformer3DModel
94- main_input_name = "hidden_states"
95- uses_custom_attn_processor = True
96-
97- @property
98- def dummy_input (self ):
99- batch_size = 2
79+ def get_dummy_inputs (self , batch_size : int = 2 ) -> dict [str , torch .Tensor ]:
10080 num_channels = 4
101- num_frames = 2
81+ num_frames = 1
10282 height = 8
10383 width = 8
10484 embedding_dim = 8
10585 sequence_length = 8
10686
107- hidden_states = torch .randn ((batch_size , num_frames , num_channels , height , width )).to (torch_device )
108- encoder_hidden_states = torch .randn ((batch_size , sequence_length , embedding_dim )).to (torch_device )
109- timestep = torch .randint (0 , 1000 , size = (batch_size ,)).to (torch_device )
110-
11187 return {
112- "hidden_states" : hidden_states ,
113- "encoder_hidden_states" : encoder_hidden_states ,
114- "timestep" : timestep ,
88+ "hidden_states" : randn_tensor (
89+ (batch_size , num_frames , num_channels , height , width ), generator = self .generator , device = torch_device
90+ ),
91+ "encoder_hidden_states" : randn_tensor (
92+ (batch_size , sequence_length , embedding_dim ), generator = self .generator , device = torch_device
93+ ),
94+ "timestep" : torch .randint (0 , 1000 , size = (batch_size ,), generator = self .generator ).to (torch_device ),
11595 }
11696
97+
class TestCogVideoXTransformer(CogVideoXTransformerTesterConfig, ModelTesterMixin):
    # Runs the shared ModelTesterMixin suite against the CogVideoX configuration.
    pass
100+
101+
class TestCogVideoXTransformerTraining(CogVideoXTransformerTesterConfig, TrainingTesterMixin):
    """Training-oriented tests for ``CogVideoXTransformer3DModel``."""

    def test_gradient_checkpointing_is_applied(self):
        # Only the top-level transformer module is expected to enable checkpointing.
        super().test_gradient_checkpointing_is_applied(
            expected_set={"CogVideoXTransformer3DModel"}
        )
106+
107+
class TestCogVideoXTransformerCompile(CogVideoXTransformerTesterConfig, TorchCompileTesterMixin):
    # Runs the shared torch.compile suite against the CogVideoX configuration.
    pass
110+
111+
112+ # ======================== CogVideoX 1.5 ========================
113+
114+
115+ class CogVideoX15TransformerTesterConfig (BaseModelTesterConfig ):
    @property
    def model_class(self):
        # Model class exercised by this tester configuration (CogVideoX 1.5 setup).
        return CogVideoXTransformer3DModel
119+
    @property
    def main_input_name(self) -> str:
        # Name of the primary tensor argument in the model's forward signature.
        return "hidden_states"
123+
    @property
    def output_shape(self) -> tuple:
        # Expected per-sample output shape used by the shared shape checks.
        # NOTE(review): dimension meaning is defined by the test mixins — confirm.
        return (1, 4, 8, 8)
120127
    @property
    def input_shape(self) -> tuple:
        # Expected per-sample input shape used by the shared shape checks.
        # NOTE(review): dimension meaning is defined by the test mixins — confirm.
        return (1, 4, 8, 8)
124131
125- def prepare_init_args_and_inputs_for_common (self ):
126- init_dict = {
127- # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
    @property
    def generator(self):
        # A fresh CPU generator seeded to 0 on EVERY access, so each tensor
        # drawn with it is reproducible independently of access order.
        return torch.Generator("cpu").manual_seed(0)
135+
136+ def get_init_dict (self ) -> dict :
137+ return {
128138 "num_attention_heads" : 2 ,
129139 "attention_head_dim" : 8 ,
130140 "in_channels" : 4 ,
@@ -141,9 +151,29 @@ def prepare_init_args_and_inputs_for_common(self):
141151 "max_text_seq_length" : 8 ,
142152 "use_rotary_positional_embeddings" : True ,
143153 }
144- inputs_dict = self .dummy_input
145- return init_dict , inputs_dict
146154
147- def test_gradient_checkpointing_is_applied (self ):
148- expected_set = {"CogVideoXTransformer3DModel" }
149- super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
155+ def get_dummy_inputs (self , batch_size : int = 2 ) -> dict [str , torch .Tensor ]:
156+ num_channels = 4
157+ num_frames = 2
158+ height = 8
159+ width = 8
160+ embedding_dim = 8
161+ sequence_length = 8
162+
163+ return {
164+ "hidden_states" : randn_tensor (
165+ (batch_size , num_frames , num_channels , height , width ), generator = self .generator , device = torch_device
166+ ),
167+ "encoder_hidden_states" : randn_tensor (
168+ (batch_size , sequence_length , embedding_dim ), generator = self .generator , device = torch_device
169+ ),
170+ "timestep" : torch .randint (0 , 1000 , size = (batch_size ,), generator = self .generator ).to (torch_device ),
171+ }
172+
173+
class TestCogVideoX15Transformer(CogVideoX15TransformerTesterConfig, ModelTesterMixin):
    # Runs the shared ModelTesterMixin suite against the CogVideoX 1.5 configuration.
    pass
176+
177+
class TestCogVideoX15TransformerCompile(CogVideoX15TransformerTesterConfig, TorchCompileTesterMixin):
    # Runs the shared torch.compile suite against the CogVideoX 1.5 configuration.
    pass
0 commit comments