Skip to content

Commit f7405f2

Browse files
committed
update
1 parent b757035 commit f7405f2

File tree

3 files changed: +201 additions, -141 deletions

tests/models/transformers/test_models_transformer_cogvideox.py

Lines changed: 96 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -13,59 +13,53 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
1816
import torch
1917

2018
from diffusers import CogVideoXTransformer3DModel
21-
22-
from ...testing_utils import (
23-
enable_full_determinism,
24-
torch_device,
19+
from diffusers.utils.torch_utils import randn_tensor
20+
21+
from ...testing_utils import enable_full_determinism, torch_device
22+
from ..testing_utils import (
23+
BaseModelTesterConfig,
24+
ModelTesterMixin,
25+
TorchCompileTesterMixin,
26+
TrainingTesterMixin,
2527
)
26-
from ..test_modeling_common import ModelTesterMixin
2728

2829

2930
enable_full_determinism()
3031

3132

32-
class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
33-
model_class = CogVideoXTransformer3DModel
34-
main_input_name = "hidden_states"
35-
uses_custom_attn_processor = True
36-
model_split_percents = [0.7, 0.7, 0.8]
33+
# ======================== CogVideoX ========================
3734

35+
36+
class CogVideoXTransformerTesterConfig(BaseModelTesterConfig):
3837
@property
39-
def dummy_input(self):
40-
batch_size = 2
41-
num_channels = 4
42-
num_frames = 1
43-
height = 8
44-
width = 8
45-
embedding_dim = 8
46-
sequence_length = 8
38+
def model_class(self):
39+
return CogVideoXTransformer3DModel
4740

48-
hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
49-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
50-
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
41+
@property
42+
def main_input_name(self) -> str:
43+
return "hidden_states"
5144

52-
return {
53-
"hidden_states": hidden_states,
54-
"encoder_hidden_states": encoder_hidden_states,
55-
"timestep": timestep,
56-
}
45+
@property
46+
def model_split_percents(self) -> list:
47+
return [0.7, 0.7, 0.8]
5748

5849
@property
59-
def input_shape(self):
50+
def output_shape(self) -> tuple:
6051
return (1, 4, 8, 8)
6152

6253
@property
63-
def output_shape(self):
54+
def input_shape(self) -> tuple:
6455
return (1, 4, 8, 8)
6556

66-
def prepare_init_args_and_inputs_for_common(self):
67-
init_dict = {
68-
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
57+
@property
58+
def generator(self):
59+
return torch.Generator("cpu").manual_seed(0)
60+
61+
def get_init_dict(self) -> dict:
62+
return {
6963
"num_attention_heads": 2,
7064
"attention_head_dim": 8,
7165
"in_channels": 4,
@@ -81,50 +75,66 @@ def prepare_init_args_and_inputs_for_common(self):
8175
"temporal_compression_ratio": 4,
8276
"max_text_seq_length": 8,
8377
}
84-
inputs_dict = self.dummy_input
85-
return init_dict, inputs_dict
8678

87-
def test_gradient_checkpointing_is_applied(self):
88-
expected_set = {"CogVideoXTransformer3DModel"}
89-
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
90-
91-
92-
class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase):
93-
model_class = CogVideoXTransformer3DModel
94-
main_input_name = "hidden_states"
95-
uses_custom_attn_processor = True
96-
97-
@property
98-
def dummy_input(self):
99-
batch_size = 2
79+
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
10080
num_channels = 4
101-
num_frames = 2
81+
num_frames = 1
10282
height = 8
10383
width = 8
10484
embedding_dim = 8
10585
sequence_length = 8
10686

107-
hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
108-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
109-
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
110-
11187
return {
112-
"hidden_states": hidden_states,
113-
"encoder_hidden_states": encoder_hidden_states,
114-
"timestep": timestep,
88+
"hidden_states": randn_tensor(
89+
(batch_size, num_frames, num_channels, height, width), generator=self.generator, device=torch_device
90+
),
91+
"encoder_hidden_states": randn_tensor(
92+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
93+
),
94+
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
11595
}
11696

97+
98+
class TestCogVideoXTransformer(CogVideoXTransformerTesterConfig, ModelTesterMixin):
99+
pass
100+
101+
102+
class TestCogVideoXTransformerTraining(CogVideoXTransformerTesterConfig, TrainingTesterMixin):
103+
def test_gradient_checkpointing_is_applied(self):
104+
expected_set = {"CogVideoXTransformer3DModel"}
105+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
106+
107+
108+
class TestCogVideoXTransformerCompile(CogVideoXTransformerTesterConfig, TorchCompileTesterMixin):
109+
pass
110+
111+
112+
# ======================== CogVideoX 1.5 ========================
113+
114+
115+
class CogVideoX15TransformerTesterConfig(BaseModelTesterConfig):
117116
@property
118-
def input_shape(self):
117+
def model_class(self):
118+
return CogVideoXTransformer3DModel
119+
120+
@property
121+
def main_input_name(self) -> str:
122+
return "hidden_states"
123+
124+
@property
125+
def output_shape(self) -> tuple:
119126
return (1, 4, 8, 8)
120127

121128
@property
122-
def output_shape(self):
129+
def input_shape(self) -> tuple:
123130
return (1, 4, 8, 8)
124131

125-
def prepare_init_args_and_inputs_for_common(self):
126-
init_dict = {
127-
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
132+
@property
133+
def generator(self):
134+
return torch.Generator("cpu").manual_seed(0)
135+
136+
def get_init_dict(self) -> dict:
137+
return {
128138
"num_attention_heads": 2,
129139
"attention_head_dim": 8,
130140
"in_channels": 4,
@@ -141,9 +151,29 @@ def prepare_init_args_and_inputs_for_common(self):
141151
"max_text_seq_length": 8,
142152
"use_rotary_positional_embeddings": True,
143153
}
144-
inputs_dict = self.dummy_input
145-
return init_dict, inputs_dict
146154

147-
def test_gradient_checkpointing_is_applied(self):
148-
expected_set = {"CogVideoXTransformer3DModel"}
149-
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
155+
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
156+
num_channels = 4
157+
num_frames = 2
158+
height = 8
159+
width = 8
160+
embedding_dim = 8
161+
sequence_length = 8
162+
163+
return {
164+
"hidden_states": randn_tensor(
165+
(batch_size, num_frames, num_channels, height, width), generator=self.generator, device=torch_device
166+
),
167+
"encoder_hidden_states": randn_tensor(
168+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
169+
),
170+
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
171+
}
172+
173+
174+
class TestCogVideoX15Transformer(CogVideoX15TransformerTesterConfig, ModelTesterMixin):
175+
pass
176+
177+
178+
class TestCogVideoX15TransformerCompile(CogVideoX15TransformerTesterConfig, TorchCompileTesterMixin):
179+
pass

tests/models/transformers/test_models_transformer_cogview3plus.py

Lines changed: 54 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -13,63 +13,50 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
1816
import torch
1917

2018
from diffusers import CogView3PlusTransformer2DModel
19+
from diffusers.utils.torch_utils import randn_tensor
2120

22-
from ...testing_utils import (
23-
enable_full_determinism,
24-
torch_device,
21+
from ...testing_utils import enable_full_determinism, torch_device
22+
from ..testing_utils import (
23+
BaseModelTesterConfig,
24+
ModelTesterMixin,
25+
TorchCompileTesterMixin,
26+
TrainingTesterMixin,
2527
)
26-
from ..test_modeling_common import ModelTesterMixin
2728

2829

2930
enable_full_determinism()
3031

3132

32-
class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
33-
model_class = CogView3PlusTransformer2DModel
34-
main_input_name = "hidden_states"
35-
uses_custom_attn_processor = True
36-
model_split_percents = [0.7, 0.6, 0.6]
37-
33+
class CogView3PlusTransformerTesterConfig(BaseModelTesterConfig):
3834
@property
39-
def dummy_input(self):
40-
batch_size = 2
41-
num_channels = 4
42-
height = 8
43-
width = 8
44-
embedding_dim = 8
45-
sequence_length = 8
35+
def model_class(self):
36+
return CogView3PlusTransformer2DModel
4637

47-
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
48-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
49-
original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
50-
target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
51-
crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
52-
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
38+
@property
39+
def main_input_name(self) -> str:
40+
return "hidden_states"
5341

54-
return {
55-
"hidden_states": hidden_states,
56-
"encoder_hidden_states": encoder_hidden_states,
57-
"original_size": original_size,
58-
"target_size": target_size,
59-
"crop_coords": crop_coords,
60-
"timestep": timestep,
61-
}
42+
@property
43+
def model_split_percents(self) -> list:
44+
return [0.7, 0.6, 0.6]
6245

6346
@property
64-
def input_shape(self):
47+
def output_shape(self) -> tuple:
6548
return (1, 4, 8, 8)
6649

6750
@property
68-
def output_shape(self):
51+
def input_shape(self) -> tuple:
6952
return (1, 4, 8, 8)
7053

71-
def prepare_init_args_and_inputs_for_common(self):
72-
init_dict = {
54+
@property
55+
def generator(self):
56+
return torch.Generator("cpu").manual_seed(0)
57+
58+
def get_init_dict(self) -> dict:
59+
return {
7360
"patch_size": 2,
7461
"in_channels": 4,
7562
"num_layers": 2,
@@ -82,9 +69,37 @@ def prepare_init_args_and_inputs_for_common(self):
8269
"pos_embed_max_size": 8,
8370
"sample_size": 8,
8471
}
85-
inputs_dict = self.dummy_input
86-
return init_dict, inputs_dict
8772

73+
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
74+
num_channels = 4
75+
height = 8
76+
width = 8
77+
embedding_dim = 8
78+
sequence_length = 8
79+
80+
return {
81+
"hidden_states": randn_tensor(
82+
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
83+
),
84+
"encoder_hidden_states": randn_tensor(
85+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
86+
),
87+
"original_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
88+
"target_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
89+
"crop_coords": torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
90+
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
91+
}
92+
93+
94+
class TestCogView3PlusTransformer(CogView3PlusTransformerTesterConfig, ModelTesterMixin):
95+
pass
96+
97+
98+
class TestCogView3PlusTransformerTraining(CogView3PlusTransformerTesterConfig, TrainingTesterMixin):
8899
def test_gradient_checkpointing_is_applied(self):
89100
expected_set = {"CogView3PlusTransformer2DModel"}
90101
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
102+
103+
104+
class TestCogView3PlusTransformerCompile(CogView3PlusTransformerTesterConfig, TorchCompileTesterMixin):
105+
pass

0 commit comments

Comments
 (0)