Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
598 changes: 366 additions & 232 deletions examples/09_sasrec_example.ipynb

Large diffs are not rendered by default.

446 changes: 266 additions & 180 deletions examples/15_twotower_example.ipynb

Large diffs are not rendered by default.

17 changes: 13 additions & 4 deletions replay/nn/lightning/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
self._lr_scheduler_factory = lr_scheduler_factory
self.candidates_to_score = None

def forward(self, batch: dict) -> Union[TrainOutput, InferenceOutput]:
def forward(self, batch: dict, return_info: bool = False) -> Union[TrainOutput, InferenceOutput]:
"""
Implementation of the forward function.

Expand All @@ -57,12 +57,21 @@ def forward(self, batch: dict) -> Union[TrainOutput, InferenceOutput]:
batch["candidates_to_score"] = self.candidates_to_score
# select only args for model.forward
modified_batch = {k: v for k, v in batch.items() if k in inspect.signature(self.model.forward).parameters}
return self.model(**modified_batch)
return self.model(**modified_batch, return_info=return_info)

def training_step(self, batch: dict) -> torch.Tensor:
model_output: TrainOutput = self(batch)
loss = model_output["loss"]
model_output: TrainOutput = self(batch, return_info=True)
loss, info = model_output["loss"], model_output.get("info", None)
lr = self.optimizers().param_groups[0]["lr"] # Get current learning rate
if info is not None:
assert isinstance(info, dict)
self.log_dict(
dictionary=info,
on_step=True,
on_epoch=True,
prog_bar=True,
sync_dist=True,
)
self.log("learning_rate", lr, on_step=True, on_epoch=True, prog_bar=True)
self.log(
"train_loss",
Expand Down
6 changes: 5 additions & 1 deletion replay/nn/loss/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .base import LossProto
from .base import LossInfo, LossOutput, LossProto
from .bce import BCE, BCESampled
from .ce import CE, CESampled, CESampledWeighted, CEWeighted
from .composed import ComposedLoss
from .login_ce import LogInCE, LogInCESampled
from .logout_ce import LogOutCE, LogOutCEWeighted

Expand All @@ -13,10 +14,13 @@
"CESampled",
"CESampledWeighted",
"CEWeighted",
"ComposedLoss",
"LogInCE",
"LogInCESampled",
"LogOutCE",
"LogOutCESampled",
"LogOutCEWeighted",
"LossInfo",
"LossOutput",
"LossProto",
]
26 changes: 23 additions & 3 deletions replay/nn/loss/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,23 @@

from replay.data.nn import TensorMap

# Named side values a loss can report next to its scalar value: string keys mapped to
# tensors or plain floats.
LossInfo = dict[str, torch.Tensor | float]
# What a loss returns: the loss tensor paired with either ``None`` or a ``LossInfo`` dict.
LossOutput = tuple[torch.Tensor, None] | tuple[torch.Tensor, LossInfo]
# Callable taking a tensor plus an optional second tensor and returning a tensor; used as the
# type of the ``logits_callback`` property.
# NOTE(review): presumably (embeddings, optional item ids) -> logits -- confirm against the callers.
LogitsCallback = Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]


class LossProto(Protocol):
"""Class-protocol for working with losses inside models"""

loss_name: str

@property
def logits_callback(
self,
) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: ...
) -> LogitsCallback: ...

@logits_callback.setter
def logits_callback(self, func: Optional[Callable]) -> None: ...
def logits_callback(self, func: LogitsCallback) -> None: ...

def forward(
self,
Expand All @@ -24,7 +30,19 @@ def forward(
negative_labels: torch.LongTensor,
padding_mask: torch.BoolTensor,
target_padding_mask: torch.BoolTensor,
) -> torch.Tensor: ...
return_info: bool = False,
) -> LossOutput: ...

# Declares that a loss object can be invoked directly with the listed tensors and an optional
# ``return_info`` flag (default ``False``), yielding a ``LossOutput``.
# Protocol stub only: the ``...`` body carries no behavior.
def __call__(
self,
model_embeddings: torch.Tensor,
feature_tensors: TensorMap,
positive_labels: torch.LongTensor,
negative_labels: torch.LongTensor,
padding_mask: torch.BoolTensor,
target_padding_mask: torch.BoolTensor,
return_info: bool = False,
) -> LossOutput: ...


class SampledLossOutput(TypedDict):
Expand All @@ -39,6 +57,8 @@ class SampledLossOutput(TypedDict):
class SampledLossBase(torch.nn.Module):
"""The base class for calculating sampled losses"""

_logits_callback: LogitsCallback | None

@property
def logits_callback(
self,
Expand Down
40 changes: 25 additions & 15 deletions replay/nn/loss/bce.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import Callable, Optional

import torch

from replay.data.nn import TensorMap

from .base import SampledLossBase, mask_negative_logits
from .base import LogitsCallback, LossOutput, SampledLossBase, mask_negative_logits


class BCE(torch.nn.Module):
Expand All @@ -16,19 +14,20 @@ class BCE(torch.nn.Module):
(there are several labels for each position in the sequence).
"""

def __init__(self, **kwargs):
def __init__(self, loss_name: str = "BCELoss", **kwargs):
"""
To calculate the loss, ``torch.nn.BCEWithLogitsLoss`` is used with the parameter ``reduction="sum"``.
You can pass all other parameters for initializing the object via kwargs.
"""
super().__init__()
self._loss = torch.nn.BCEWithLogitsLoss(reduction="sum", **kwargs)
self._logits_callback = None
self._logits_callback: LogitsCallback | None = None
self.loss_name: str = loss_name

@property
def logits_callback(
self,
) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
) -> LogitsCallback:
"""
Property for calling a function for the logits computation.\n

Expand All @@ -46,7 +45,7 @@ def logits_callback(
return self._logits_callback

@logits_callback.setter
def logits_callback(self, func: Optional[Callable]) -> None:
def logits_callback(self, func: LogitsCallback) -> None:
self._logits_callback = func

def forward(
Expand All @@ -57,7 +56,8 @@ def forward(
negative_labels: torch.LongTensor, # noqa: ARG002
padding_mask: torch.BoolTensor, # noqa: ARG002
target_padding_mask: torch.BoolTensor,
) -> torch.Tensor:
return_info: bool = False,
) -> LossOutput:
"""
forward(model_embeddings, positive_labels, target_padding_mask)
:param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
Expand Down Expand Up @@ -92,7 +92,11 @@ def forward(
)

loss = self._loss(logits, bce_labels) / logits.size(0)
return loss

if return_info:
return (loss, {self.loss_name: loss.detach()})
else:
return (loss, None)


class BCESampled(SampledLossBase):
Expand All @@ -109,7 +113,8 @@ def __init__(
log_epsilon: float = 1e-6,
clamp_border: float = 100.0,
negative_labels_ignore_index: int = -100,
):
loss_name: str = "BCESampledLoss",
) -> None:
"""
:param log_epsilon: correction to avoid zero in the logarithm during loss calculating.
Default: ``1e-6``.
Expand All @@ -125,12 +130,13 @@ def __init__(
self.log_epsilon = log_epsilon
self.clamp_border = clamp_border
self.negative_labels_ignore_index = negative_labels_ignore_index
self._logits_callback = None
self._logits_callback: LogitsCallback | None = None
self.loss_name: str = loss_name

@property
def logits_callback(
self,
) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
) -> LogitsCallback:
"""
Property for calling a function for the logits computation.\n

Expand All @@ -148,7 +154,7 @@ def logits_callback(
return self._logits_callback

@logits_callback.setter
def logits_callback(self, func: Optional[Callable]) -> None:
def logits_callback(self, func: LogitsCallback) -> None:
self._logits_callback = func

def forward(
Expand All @@ -159,7 +165,8 @@ def forward(
negative_labels: torch.LongTensor,
padding_mask: torch.BoolTensor, # noqa: ARG002
target_padding_mask: torch.BoolTensor,
) -> torch.Tensor:
return_info: bool = False,
) -> LossOutput:
"""
forward(model_embeddings, positive_labels, negative_labels, target_padding_mask)

Expand Down Expand Up @@ -213,4 +220,7 @@ def forward(
loss = -(positive_loss + negative_loss)
loss /= positive_logits.size(0)

return loss
if return_info:
return (loss, {self.loss_name: loss.detach()})
else:
return (loss, None)
Loading