7
7
8
8
# pyre-strict
9
9
10
+ import copy
10
11
import random
11
12
from dataclasses import dataclass
12
13
from typing import Any , cast , Dict , List , Optional , Tuple , Type , Union
@@ -239,10 +240,16 @@ def _validate_pooling_factor(
239
240
else None
240
241
)
241
242
242
- global_float = torch .rand (
243
- (batch_size * world_size , num_float_features ), device = device
244
- )
245
- global_label = torch .rand (batch_size * world_size , device = device )
243
+ if randomize_indices :
244
+ global_float = torch .rand (
245
+ (batch_size * world_size , num_float_features ), device = device
246
+ )
247
+ global_label = torch .rand (batch_size * world_size , device = device )
248
+ else :
249
+ global_float = torch .zeros (
250
+ (batch_size * world_size , num_float_features ), device = device
251
+ )
252
+ global_label = torch .zeros (batch_size * world_size , device = device )
246
253
247
254
# Split global batch into local batches.
248
255
local_inputs = []
@@ -939,6 +946,7 @@ def __init__(
939
946
max_feature_lengths_list : Optional [List [Dict [str , int ]]] = None ,
940
947
feature_processor_modules : Optional [Dict [str , torch .nn .Module ]] = None ,
941
948
over_arch_clazz : Type [nn .Module ] = TestOverArch ,
949
+ preproc_module : Optional [nn .Module ] = None ,
942
950
) -> None :
943
951
super ().__init__ (
944
952
tables = cast (List [BaseEmbeddingConfig ], tables ),
@@ -960,13 +968,22 @@ def __init__(
960
968
embedding_names = (
961
969
list (embedding_groups .values ())[0 ] if embedding_groups else None
962
970
)
971
+ self ._embedding_names : List [str ] = (
972
+ embedding_names
973
+ if embedding_names
974
+ else [feature for table in tables for feature in table .feature_names ]
975
+ )
976
+ self ._weighted_features : List [str ] = [
977
+ feature for table in weighted_tables for feature in table .feature_names
978
+ ]
963
979
self .over : nn .Module = over_arch_clazz (
964
980
tables , weighted_tables , embedding_names , dense_device
965
981
)
966
982
self .register_buffer (
967
983
"dummy_ones" ,
968
984
torch .ones (1 , device = dense_device ),
969
985
)
986
+ self .preproc_module = preproc_module
970
987
971
988
def sparse_forward (self , input : ModelInput ) -> KeyedTensor :
972
989
return self .sparse (
@@ -993,6 +1010,8 @@ def forward(
993
1010
self ,
994
1011
input : ModelInput ,
995
1012
) -> Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]]:
1013
+ if self .preproc_module :
1014
+ input = self .preproc_module (input )
996
1015
return self .dense_forward (input , self .sparse_forward (input ))
997
1016
998
1017
@@ -1409,3 +1428,246 @@ def _post_ebc_test_wrap_function(kt: KeyedTensor) -> KeyedTensor:
1409
1428
continue
1410
1429
1411
1430
return kt
1431
+
1432
+
1433
class TestPreprocNonWeighted(nn.Module):
    """
    Minimal preproc module for tests.

    Splits an incoming KeyedJaggedTensor and rebuilds it keeping only
    ``feature_0``, ``feature_1`` and ``feature_2`` (``feature_3`` is dropped).

    Examples:
        >>> TestPreprocNonWeighted()
    Returns:
        List[KeyedJaggedTensor]: a single-element list holding the filtered KJT.
    """

    def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
        """Select features 0, 1 and 2 from ``kjt`` and re-merge them."""
        kept = ("feature_0", "feature_1", "feature_2")
        # split per feature, then merge only the kept ones (feature_3 removed)
        return [KeyedJaggedTensor.from_jt_dict({name: kjt[name] for name in kept})]
1463
+
1464
+
1465
class TestPreprocWeighted(nn.Module):
    """
    Minimal preproc module for tests.

    Extracts a single feature from a weighted KeyedJaggedTensor and rebuilds
    a KJT containing only that feature.

    Examples:
        >>> TestPreprocWeighted()
    Returns:
        List[KeyedJaggedTensor]: a single-element list holding the filtered KJT.
    """

    def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
        """Keep only ``weighted_feature_0`` of ``kjt``."""
        # split out the single feature we keep, then re-merge it
        only = kjt["weighted_feature_0"]
        return [KeyedJaggedTensor.from_jt_dict({"weighted_feature_0": only})]
1492
+
1493
+
1494
class TestModelWithPreproc(nn.Module):
    """
    Basic test model wired with up to 3 preproc modules:
    - a preproc on ``idlist_features`` feeding the non-weighted EBC
    - a preproc on ``idscore_features`` feeding the weighted EBC
    - an optional preproc applied to the whole model input before both EBCs

    Args:
        tables: configs for the non-weighted EmbeddingBagCollection.
        weighted_tables: configs for the weighted EmbeddingBagCollection.
        device: device for the embedding modules.
        preproc_module: optional module applied to the full ModelInput first.
        num_float_features: width of the dense feature input.
        run_preproc_inline: when True (and no ``preproc_module``), rebuilds
            ``idlist_features`` inline via ``from_lengths_sync``.

    Example:
        >>> TestModelWithPreproc(tables, weighted_tables, device)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]
    """

    def __init__(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        device: torch.device,
        preproc_module: Optional[nn.Module] = None,
        num_float_features: int = 10,
        run_preproc_inline: bool = False,
    ) -> None:
        super().__init__()
        self.dense = TestDenseArch(num_float_features, device)

        self.ebc: EmbeddingBagCollection = EmbeddingBagCollection(
            tables=tables,
            device=device,
        )
        self.weighted_ebc = EmbeddingBagCollection(
            tables=weighted_tables,
            is_weighted=True,
            device=device,
        )
        self.preproc_nonweighted = TestPreprocNonWeighted()
        self.preproc_weighted = TestPreprocWeighted()
        self._preproc_module = preproc_module
        self._run_preproc_inline = run_preproc_inline

    def forward(
        self,
        input: ModelInput,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs preproc for the EBC and weighted EBC, optionally running a
        preproc over the whole input first.

        Args:
            input: batch to run through the model.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (scalar sum of predictions, predictions).
        """
        features = input

        if self._preproc_module is not None:
            features = self._preproc_module(features)
        elif self._run_preproc_inline:
            # rebuild the KJT in place from its own components
            idlist = features.idlist_features
            features.idlist_features = KeyedJaggedTensor.from_lengths_sync(
                idlist.keys(),
                idlist.values(),
                idlist.lengths(),
            )

        idlist_kjt = self.preproc_nonweighted(features.idlist_features)[0]
        idscore_kjt = self.preproc_weighted(features.idscore_features)[0]

        pred = torch.cat(
            [self.ebc(idlist_kjt).values(), self.weighted_ebc(idscore_kjt).values()],
            dim=1,
        )
        return pred.sum(), pred
1576
+
1577
+
1578
class TestNegSamplingModule(torch.nn.Module):
    """
    Basic module simulating a feature-augmentation preproc (e.g. negative
    sampling) for testing: appends a fixed extra batch to every input batch.

    Args:
        extra_input: the extra batch concatenated onto each incoming batch.
        has_params: when True, registers a linear layer so the module owns
            trainable parameters (the layer is not used in ``forward``).

    Example:
        >>> preproc = TestNegSamplingModule(extra_input)
        >>> out = preproc(in)

    Returns:
        ModelInput
    """

    def __init__(
        self,
        extra_input: ModelInput,
        has_params: bool = False,
    ) -> None:
        super().__init__()
        self._extra_input = extra_input
        if has_params:
            self._linear: nn.Module = nn.Linear(30, 30)

    def forward(self, input: ModelInput) -> ModelInput:
        """
        Appends the extra features to a deep copy of the model input.

        Args:
            input: batch to augment (left unmodified).
        Returns:
            ModelInput: augmented copy of ``input``.
        """
        out = copy.deepcopy(input)
        extra = self._extra_input

        # dim=0 (batch dimension) grows by extra.float_features.shape[0]
        out.float_features = torch.concat(
            (out.float_features, extra.float_features), dim=0
        )

        # stride stays the same; the feature values are joined
        out.idlist_features = KeyedJaggedTensor.concat(
            [out.idlist_features, extra.idlist_features]
        )

        if extra.idscore_features is not None:
            # stride stays the same; the feature values are joined
            out.idscore_features = KeyedJaggedTensor.concat(
                # pyre-ignore
                [out.idscore_features, extra.idscore_features]
            )

        # dim=0 (batch dimension) grows by extra.label.shape[0]
        out.label = torch.concat((out.label, extra.label), dim=0)

        return out
1639
+
1640
+
1641
class TestPositionWeightedPreprocModule(torch.nn.Module):
    """
    Basic test preproc that runs a ``PositionWeightedProcessor`` over the
    input's ``idlist_features``.

    Example:
        >>> preproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
        >>> out = preproc(in)
    Returns:
        ModelInput
    """

    def __init__(
        self, max_feature_lengths: Dict[str, int], device: torch.device
    ) -> None:
        super().__init__()
        self.fp_proc = PositionWeightedProcessor(
            max_feature_lengths=max_feature_lengths,
            device=device,
        )

    def forward(self, input: ModelInput) -> ModelInput:
        """
        Applies the PositionWeightedProcessor to a deep copy of the input.

        Args:
            input: batch to transform (left unmodified).
        Returns:
            ModelInput: copy with processed ``idlist_features``.
        """
        out = copy.deepcopy(input)
        out.idlist_features = self.fp_proc(out.idlist_features)
        return out
0 commit comments