Skip to content

Commit eab0f86

Browse files
sarckk authored and facebook-github-bot committed
Add ability to specify pipelineable preproc modules to ignore during SDD model rewrite (#2149)
Summary: Pull Request resolved: #2149 Make torchrec automatically pipeline any modules that don't have trainable params during sparse data dist pipelining. tldr; with some traversal logic changes, TorchRec sparse data dist pipeline can support arbitrary input transformations at input dist stage as long as they are composed of either nn.Module calls or currently supported ops (mainly getattr and getitem) Differential Revision: D57944338
1 parent 3eff5ce commit eab0f86

File tree

6 files changed

+1158
-35
lines changed

6 files changed

+1158
-35
lines changed

torchrec/distributed/test_utils/test_model.py

Lines changed: 266 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# pyre-strict
99

10+
import copy
1011
import random
1112
from dataclasses import dataclass
1213
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
@@ -239,10 +240,16 @@ def _validate_pooling_factor(
239240
else None
240241
)
241242

242-
global_float = torch.rand(
243-
(batch_size * world_size, num_float_features), device=device
244-
)
245-
global_label = torch.rand(batch_size * world_size, device=device)
243+
if randomize_indices:
244+
global_float = torch.rand(
245+
(batch_size * world_size, num_float_features), device=device
246+
)
247+
global_label = torch.rand(batch_size * world_size, device=device)
248+
else:
249+
global_float = torch.zeros(
250+
(batch_size * world_size, num_float_features), device=device
251+
)
252+
global_label = torch.zeros(batch_size * world_size, device=device)
246253

247254
# Split global batch into local batches.
248255
local_inputs = []
@@ -939,6 +946,7 @@ def __init__(
939946
max_feature_lengths_list: Optional[List[Dict[str, int]]] = None,
940947
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
941948
over_arch_clazz: Type[nn.Module] = TestOverArch,
949+
preproc_module: Optional[nn.Module] = None,
942950
) -> None:
943951
super().__init__(
944952
tables=cast(List[BaseEmbeddingConfig], tables),
@@ -960,13 +968,22 @@ def __init__(
960968
embedding_names = (
961969
list(embedding_groups.values())[0] if embedding_groups else None
962970
)
971+
self._embedding_names: List[str] = (
972+
embedding_names
973+
if embedding_names
974+
else [feature for table in tables for feature in table.feature_names]
975+
)
976+
self._weighted_features: List[str] = [
977+
feature for table in weighted_tables for feature in table.feature_names
978+
]
963979
self.over: nn.Module = over_arch_clazz(
964980
tables, weighted_tables, embedding_names, dense_device
965981
)
966982
self.register_buffer(
967983
"dummy_ones",
968984
torch.ones(1, device=dense_device),
969985
)
986+
self.preproc_module = preproc_module
970987

971988
def sparse_forward(self, input: ModelInput) -> KeyedTensor:
972989
return self.sparse(
@@ -993,6 +1010,8 @@ def forward(
9931010
self,
9941011
input: ModelInput,
9951012
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
1013+
if self.preproc_module:
1014+
input = self.preproc_module(input)
9961015
return self.dense_forward(input, self.sparse_forward(input))
9971016

9981017

@@ -1409,3 +1428,246 @@ def _post_ebc_test_wrap_function(kt: KeyedTensor) -> KeyedTensor:
14091428
continue
14101429

14111430
return kt
1431+
1432+
1433+
class TestPreprocNonWeighted(nn.Module):
    """
    Test preproc module that keeps a fixed subset of KJT features.

    Args: None
    Examples:
        >>> TestPreprocNonWeighted()
    Returns:
        List[KeyedJaggedTensor]
    """

    def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
        """
        Selects 3 features from a specific KJT.
        """
        # Split out features 0-2 and re-assemble them into one KJT;
        # feature_3 is intentionally dropped.
        kept = {
            name: kjt[name] for name in ("feature_0", "feature_1", "feature_2")
        }
        return [KeyedJaggedTensor.from_jt_dict(kept)]
1463+
1464+
1465+
class TestPreprocWeighted(nn.Module):
    """
    Test preproc module that keeps a single feature of a weighted KJT.

    Args: None
    Examples:
        >>> TestPreprocWeighted()
    Returns:
        List[KeyedJaggedTensor]
    """

    def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
        """
        Selects 1 feature from a specific weighted KJT.
        """
        # Keep only weighted_feature_0; everything else is dropped.
        filtered = KeyedJaggedTensor.from_jt_dict(
            {"weighted_feature_0": kjt["weighted_feature_0"]}
        )
        return [filtered]
1492+
1493+
1494+
class TestModelWithPreproc(nn.Module):
    """
    Test model wiring up to 3 preproc modules:
    - preproc on idlist_features for the non-weighted EBC
    - preproc on idscore_features for the weighted EBC
    - optional preproc on the whole model input, shared by both EBCs

    Args:
        tables,
        weighted_tables,
        device,
        preproc_module,
        num_float_features,
        run_preproc_inline,

    Example:
        >>> TestModelWithPreproc(tables, weighted_tables, device)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]
    """

    def __init__(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        device: torch.device,
        preproc_module: Optional[nn.Module] = None,
        num_float_features: int = 10,
        run_preproc_inline: bool = False,
    ) -> None:
        super().__init__()
        self.dense = TestDenseArch(num_float_features, device)

        self.ebc: EmbeddingBagCollection = EmbeddingBagCollection(
            tables=tables,
            device=device,
        )
        self.weighted_ebc = EmbeddingBagCollection(
            tables=weighted_tables,
            is_weighted=True,
            device=device,
        )
        self.preproc_nonweighted = TestPreprocNonWeighted()
        self.preproc_weighted = TestPreprocWeighted()
        self._preproc_module = preproc_module
        self._run_preproc_inline = run_preproc_inline

    def forward(
        self,
        input: ModelInput,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs preproc for the EBC and weighted EBC; optionally runs a
        shared preproc on the whole model input first.

        Args:
            input
        Returns:
            Tuple[torch.Tensor, torch.Tensor]
        """
        modified_input = input

        # Either run the shared preproc module, or (to exercise non-module
        # inline preproc) rebuild the idlist KJT in place.
        if self._preproc_module is not None:
            modified_input = self._preproc_module(modified_input)
        elif self._run_preproc_inline:
            idlist = modified_input.idlist_features
            modified_input.idlist_features = KeyedJaggedTensor.from_lengths_sync(
                idlist.keys(),
                idlist.values(),
                idlist.lengths(),
            )

        # Each preproc returns a single-element list of KJTs.
        idlist_out = self.preproc_nonweighted(modified_input.idlist_features)
        idscore_out = self.preproc_weighted(modified_input.idscore_features)

        ebc_out = self.ebc(idlist_out[0])
        weighted_ebc_out = self.weighted_ebc(idscore_out[0])

        pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1)
        return pred.sum(), pred
1576+
1577+
1578+
class TestNegSamplingModule(torch.nn.Module):
    """
    Test module simulating a feature-augmentation preproc (e.g. negative
    sampling) by concatenating a fixed extra batch onto the input.

    Args:
        extra_input
        has_params

    Example:
        >>> preproc = TestNegSamplingModule(extra_input)
        >>> out = preproc(in)

    Returns:
        ModelInput
    """

    def __init__(
        self,
        extra_input: ModelInput,
        has_params: bool = False,
    ) -> None:
        super().__init__()
        self._extra_input = extra_input
        # Optionally hold a trainable parameter so the module is treated
        # as non-pipelineable by code that checks for trainable params.
        if has_params:
            self._linear: nn.Module = nn.Linear(30, 30)

    def forward(self, input: ModelInput) -> ModelInput:
        """
        Appends extra features to a deep copy of the model input.

        Args:
            input
        Returns:
            ModelInput
        """
        # Work on a copy so the caller's input is left untouched.
        augmented = copy.deepcopy(input)

        # dim=0 (batch dimension) grows by self._extra_input.float_features.shape[0]
        augmented.float_features = torch.concat(
            (augmented.float_features, self._extra_input.float_features), dim=0
        )

        # stride will be same but features will be joined
        augmented.idlist_features = KeyedJaggedTensor.concat(
            [augmented.idlist_features, self._extra_input.idlist_features]
        )

        if self._extra_input.idscore_features is not None:
            # stride will be same but features will be joined
            augmented.idscore_features = KeyedJaggedTensor.concat(
                # pyre-ignore
                [augmented.idscore_features, self._extra_input.idscore_features]
            )

        # dim=0 (batch dimension) grows by self._extra_input.label.shape[0]
        augmented.label = torch.concat(
            (augmented.label, self._extra_input.label), dim=0
        )

        return augmented
1639+
1640+
1641+
class TestPositionWeightedPreprocModule(torch.nn.Module):
    """
    Test preproc module that applies PositionWeightedProcessor to the
    idlist features of a model input.

    Args: None
    Example:
        >>> preproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
        >>> out = preproc(in)
    Returns:
        ModelInput
    """

    def __init__(
        self, max_feature_lengths: Dict[str, int], device: torch.device
    ) -> None:
        super().__init__()
        self.fp_proc = PositionWeightedProcessor(
            max_feature_lengths=max_feature_lengths,
            device=device,
        )

    def forward(self, input: ModelInput) -> ModelInput:
        """
        Runs PositionWeightedProcessor on a deep copy of the input.

        Args:
            input
        Returns:
            ModelInput
        """
        # Copy first so the caller's input KJT is not mutated.
        processed = copy.deepcopy(input)
        processed.idlist_features = self.fp_proc(processed.idlist_features)
        return processed

0 commit comments

Comments
 (0)