From bcf3c8b37db9b5074dc24706877f21c73ebc564e Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Tue, 3 Jun 2025 13:58:38 +0300 Subject: [PATCH 1/3] context net draft --- rectools/dataset/dataset.py | 54 +++++++++++- rectools/models/nn/transformers/base.py | 41 ++++++++- .../models/nn/transformers/context_net.py | 79 +++++++++++++++++ .../models/nn/transformers/data_preparator.py | 88 ++++++++++++++++--- rectools/models/nn/transformers/sasrec.py | 58 +++++++++--- .../models/nn/transformers/torch_backbone.py | 6 ++ 6 files changed, 294 insertions(+), 32 deletions(-) create mode 100644 rectools/models/nn/transformers/context_net.py diff --git a/rectools/dataset/dataset.py b/rectools/dataset/dataset.py index 6d7a7d52..68630666 100644 --- a/rectools/dataset/dataset.py +++ b/rectools/dataset/dataset.py @@ -78,6 +78,14 @@ class SparseFeaturesSchema(BaseFeaturesSchema): cat_n_stored_values: int +class InteractionsFeaturesSchema(BaseConfig): + """Interactions features schema.""" + + cat_feature_names: tp.List[str] + cat_feature_names_w_values: tp.List[tp.Tuple[str, str]] + direct_feature_names: tp.List[str] + + FeaturesSchema = tp.Union[DenseFeaturesSchema, SparseFeaturesSchema] @@ -102,6 +110,7 @@ class DatasetSchema(BaseConfig): n_interactions: int users: EntitySchema items: EntitySchema + interactions: tp.Optional[InteractionsFeaturesSchema] = None @attr.s(slots=True, frozen=True) @@ -135,6 +144,7 @@ class Dataset: interactions: Interactions = attr.ib() user_features: tp.Optional[Features] = attr.ib(default=None) item_features: tp.Optional[Features] = attr.ib(default=None) + interactions_schema: tp.Optional[InteractionsFeaturesSchema] = attr.ib(default=None) @staticmethod def _get_feature_schema(features: tp.Optional[Features]) -> tp.Optional[FeaturesSchema]: @@ -170,6 +180,7 @@ def get_schema(self) -> DatasetSchemaDict: n_interactions=self.interactions.df.shape[0], users=user_schema, items=item_schema, + interactions=self.interactions_schema, ) return schema.model_dump(mode="json") @@ -206,7 +217,7 @@ def get_hot_item_features(self) -> tp.Optional[Features]: return self.item_features.take(range(self.n_hot_items)) @classmethod - def construct( + def construct( # pylint: disable=too-many-locals cls, interactions_df: pd.DataFrame, user_features_df: tp.Optional[pd.DataFrame] = None, @@ -216,6 +227,8 @@ def construct( cat_item_features: tp.Iterable[str] = (), make_dense_item_features: bool = False, keep_extra_cols: bool = False, + interactions_cat_features: tp.Iterable[str] = (), + interactions_direct_features: tp.Iterable[str] = (), ) -> "Dataset": """Class method for convenient `Dataset` creation. @@ -249,6 +262,10 @@ def construct( - if ``True``, `DenseFeatures.from_dataframe` method will be used. keep_extra_cols: bool, default ``False`` Flag to keep all columns from interactions besides the default ones. + interactions_cat_features : tp.Iterable[str], default ``()`` + List of categorical feature names in interactions dataframe. + interactions_direct_features : tp.Iterable[str], default ``()`` + List of direct (non-categorical) feature names in interactions dataframe. 
Returns ------- @@ -258,6 +275,32 @@ def construct( for col in (Columns.User, Columns.Item): if col not in interactions_df: raise KeyError(f"Column '{col}' must be present in `interactions_df`") + + # Validate interactions features + cat_features = set(interactions_cat_features) + direct_features = set(interactions_direct_features) + required_columns = cat_features | direct_features + actual_columns = set(interactions_df.columns) + if not actual_columns >= required_columns: + raise KeyError(f"Missed columns {required_columns - actual_columns}") + + # Create interactions feature schema + cat_feature_names_w_values = [] + for cat_feature in cat_features: + values = interactions_df[cat_feature].unique() # TODO: decide NaN values + for value in values: + cat_feature_names_w_values.append((cat_feature, value)) + + interactions_schema = ( + InteractionsFeaturesSchema( + cat_feature_names=list(cat_features), + direct_feature_names=list(direct_features), + cat_feature_names_w_values=cat_feature_names_w_values, + ) + if cat_features or direct_features + else None + ) + user_id_map = IdMap.from_values(interactions_df[Columns.User].values) item_id_map = IdMap.from_values(interactions_df[Columns.Item].values) interactions = Interactions.from_raw(interactions_df, user_id_map, item_id_map, keep_extra_cols) @@ -278,7 +321,14 @@ def construct( Columns.Item, "item", ) - return cls(user_id_map, item_id_map, interactions, user_features, item_features) + return cls( + user_id_map=user_id_map, + item_id_map=item_id_map, + interactions=interactions, + user_features=user_features, + item_features=item_features, + interactions_schema=interactions_schema, + ) @staticmethod def _make_features( diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py index 14982cb9..3164d26d 100644 --- a/rectools/models/nn/transformers/base.py +++ b/rectools/models/nn/transformers/base.py @@ -38,6 +38,7 @@ ItemNetConstructorBase, SumOfEmbeddingsConstructor, ) +from .context_net import CatFeaturesContextNet, ContextNetBase from .data_preparator import InitKwargs, TransformerDataPreparatorBase from .lightning import TransformerLightningModule, TransformerLightningModuleBase from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase @@ -117,6 +118,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]: ), ] +ContextNetType = tpe.Annotated[ + tp.Type[ContextNetBase], + BeforeValidator(_get_class_obj), + PlainSerializer( + func=get_class_or_function_full_path, + return_type=str, + when_used="json", + ), +] + TransformerDataPreparatorType = tpe.Annotated[ tp.Type[TransformerDataPreparatorBase], BeforeValidator(_get_class_obj), @@ -216,6 +227,7 @@ class TransformerModelConfig(ModelConfig): negative_sampler_type: TransformerNegativeSamplerType = CatalogUniformSampler similarity_module_type: SimilarityModuleType = DistanceSimilarityModule backbone_type: TransformerBackboneType = TransformerTorchBackbone + context_net_type: ContextNetType = CatFeaturesContextNet get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None get_trainer_func: tp.Optional[TrainerCallableSerialized] = None get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None @@ -228,6 +240,7 @@ class TransformerModelConfig(ModelConfig): negative_sampler_kwargs: tp.Optional[InitKwargs] = None similarity_module_kwargs: tp.Optional[InitKwargs] = None backbone_kwargs: tp.Optional[InitKwargs] = None + context_net_kwargs: tp.Optional[InitKwargs] = None TransformerModelConfig_T = 
tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig) @@ -278,6 +291,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals negative_sampler_type: tp.Type[TransformerNegativeSamplerBase] = CatalogUniformSampler, similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, + context_net_type: tp.Type[ContextNetBase] = CatFeaturesContextNet, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None, @@ -290,6 +304,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals negative_sampler_kwargs: tp.Optional[InitKwargs] = None, similarity_module_kwargs: tp.Optional[InitKwargs] = None, backbone_kwargs: tp.Optional[InitKwargs] = None, + context_net_kwargs: tp.Optional[InitKwargs] = None, **kwargs: tp.Any, ) -> None: super().__init__(verbose=verbose) @@ -321,6 +336,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals self.lightning_module_type = lightning_module_type self.negative_sampler_type = negative_sampler_type self.backbone_type = backbone_type + self.context_net_type = context_net_type self.get_val_mask_func = get_val_mask_func self.get_trainer_func = get_trainer_func self.get_val_mask_func_kwargs = get_val_mask_func_kwargs @@ -333,7 +349,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals self.negative_sampler_kwargs = negative_sampler_kwargs self.similarity_module_kwargs = similarity_module_kwargs self.backbone_kwargs = backbone_kwargs - + self.context_net_kwargs = context_net_kwargs self._init_data_preparator() self._init_trainer() @@ -392,6 +408,16 @@ def _construct_item_net(self, dataset: Dataset) -> ItemNetBase: **self._get_kwargs(self.item_net_constructor_kwargs), ) + def _construct_context_net(self, dataset_schema: DatasetSchema) -> tp.Optional[ContextNetBase]: + if dataset_schema.interactions is None: + return None + return self.context_net_type.from_dataset_schema( + dataset_schema, + self.n_factors, + self.dropout_rate, + **self._get_kwargs(self.context_net_kwargs), + ) + def _construct_item_net_from_dataset_schema(self, dataset_schema: DatasetSchema) -> ItemNetBase: return self.item_net_constructor_type.from_dataset_schema( dataset_schema, @@ -421,7 +447,9 @@ def _init_transformer_layers(self) -> TransformerLayersBase: def _init_similarity_module(self) -> SimilarityModuleBase: return self.similarity_module_type(**self._get_kwargs(self.similarity_module_kwargs)) - def _init_torch_model(self, item_model: ItemNetBase) -> TransformerBackboneBase: + def _init_torch_model( + self, item_model: ItemNetBase, context_net: tp.Optional[ContextNetBase] + ) -> TransformerBackboneBase: pos_encoding_layer = self._init_pos_encoding_layer() transformer_layers = self._init_transformer_layers() similarity_module = self._init_similarity_module() @@ -429,6 +457,7 @@ def _init_torch_model(self, item_model: ItemNetBase) -> TransformerBackboneBase: n_heads=self.n_heads, dropout_rate=self.dropout_rate, item_model=item_model, + context_net=context_net, pos_encoding_layer=pos_encoding_layer, transformer_layers=transformer_layers, similarity_module=similarity_module, @@ -464,7 +493,10 @@ def _init_lightning_model( def _build_model_from_dataset(self, dataset: Dataset) -> None: self.data_preparator.process_dataset_train(dataset) item_model = self._construct_item_net(self.data_preparator.train_dataset) - torch_model 
= self._init_torch_model(item_model) + context_net = self._construct_context_net( + DatasetSchema.model_validate(self.data_preparator.train_dataset.get_schema()) + ) + torch_model = self._init_torch_model(item_model, context_net) dataset_schema = self.data_preparator.train_dataset.get_schema() item_external_ids = self.data_preparator.train_dataset.item_id_map.external_ids @@ -589,7 +621,8 @@ def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self: # Init and update torch model and lightning model item_model = loaded._construct_item_net_from_dataset_schema(dataset_schema) - torch_model = loaded._init_torch_model(item_model) + context_net = loaded._construct_context_net(dataset_schema) + torch_model = loaded._init_torch_model(item_model, context_net) loaded._init_lightning_model( torch_model=torch_model, dataset_schema=dataset_schema, diff --git a/rectools/models/nn/transformers/context_net.py b/rectools/models/nn/transformers/context_net.py new file mode 100644 index 00000000..74ec01d6 --- /dev/null +++ b/rectools/models/nn/transformers/context_net.py @@ -0,0 +1,79 @@ +import typing as tp +import warnings + +import torch +import typing_extensions as tpe +from torch import nn + +from rectools.dataset.dataset import DatasetSchema + +# TODO: support non-string values in feature names/values + + +class ContextNetBase(torch.nn.Module): + """TODO.""" + + def __init__(self, n_factors: int, dropout_rate: float, **kwargs: tp.Any): + super().__init__() + + def forward(self, seqs: torch.Tensor, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor: + """TODO.""" + raise NotImplementedError + + @classmethod + def from_dataset_schema( + cls, dataset_schema: DatasetSchema, *args: tp.Any, **kwargs: tp.Any + ) -> tp.Optional[tpe.Self]: + """Construct ItemNet from Dataset schema.""" + raise NotImplementedError() + + @property + def out_dim(self) -> int: + """Return item embedding output dimension.""" + raise NotImplementedError() + + +class CatFeaturesContextNet(ContextNetBase): + """TODO.""" + + def __init__(self, n_factors: int, dropout_rate: float, n_cat_feature_values: int, **kwargs: tp.Any) -> None: + super().__init__(n_factors, dropout_rate, **kwargs) + print(n_cat_feature_values) + self.embedding_bag = nn.EmbeddingBag(num_embeddings=n_cat_feature_values, embedding_dim=n_factors, mode="sum") + self.dropout = nn.Dropout(dropout_rate) + + @classmethod + def from_dataset_schema( # TODO: decide about target aware schema + cls, dataset_schema: DatasetSchema, n_factors: int, dropout_rate: float, **kwargs: tp.Any + ) -> tp.Optional[tpe.Self]: + """TODO.""" + if dataset_schema.interactions is None: + warnings.warn("No interactions schema found in dataset schema, context net will not be constructed") + return None + if dataset_schema.interactions.direct_feature_names: + warnings.warn("Direct features are not supported in context net") + if len(dataset_schema.interactions.cat_feature_names_w_values) == 0: + warnings.warn("No categorical features found in dataset schema, context net will not be constructed") + return None + n_cat_feature_values = len(dataset_schema.interactions.cat_feature_names_w_values) + return cls(n_factors=n_factors, dropout_rate=dropout_rate, n_cat_feature_values=n_cat_feature_values) + + def forward(self, seqs: torch.Tensor, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor: + """TODO.""" + # TODO: check correctness and remove offsets from batch + b, l, f = seqs.shape + offsets = batch["context_cat_offsets"].view(-1) + offsets = torch.cat([torch.zeros(1, 
dtype=offsets.dtype, device=offsets.device), offsets]) + offsets = offsets.cumsum(dim=0)[:-1] + + inputs = batch["context_cat_inputs"] + new_inputs = inputs.view(b * l, -1) + context_embs = self.embedding_bag(input=new_inputs) + context_embs = self.dropout(context_embs) + context_embs = context_embs.view(b, l, f) + return seqs + context_embs + + @property + def out_dim(self) -> int: + """Return categorical item embedding output dimension.""" + return self.embedding_bag.embedding_dim diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index b13ec87d..4c7de9a3 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -24,7 +24,7 @@ from torch.utils.data import Dataset as TorchDataset from rectools import Columns, ExternalIds -from rectools.dataset import Dataset, Interactions +from rectools.dataset.dataset import Dataset, Interactions, InteractionsFeaturesSchema from rectools.dataset.features import DenseFeatures, Features, SparseFeatures from rectools.dataset.identifiers import IdMap @@ -32,6 +32,8 @@ from .negative_sampler import TransformerNegativeSamplerBase InitKwargs = tp.Dict[str, tp.Any] +PayloadsSpec = tp.Dict[str, tp.List[tp.Any]] +BatchElement = tp.Tuple[tp.List[int], tp.List[float], PayloadsSpec] class SequenceDataset(TorchDataset): @@ -46,23 +48,33 @@ class SequenceDataset(TorchDataset): Weight of each interaction from the session. """ - def __init__(self, sessions: tp.List[tp.List[int]], weights: tp.List[tp.List[float]]): + def __init__( + self, + sessions: tp.List[tp.List[int]], + weights: tp.List[tp.List[float]], + payloads: tp.Optional[PayloadsSpec] = None, + ): self.sessions = sessions self.weights = weights + self.payloads = payloads if payloads is not None else {} def __len__(self) -> int: return len(self.sessions) - def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float]]: + def __getitem__(self, index: int) -> BatchElement: session = self.sessions[index] # [session_len] weights = self.weights[index] # [session_len] - return session, weights + if self.payloads: + payloads = {feature_name: features[index] for feature_name, features in self.payloads.items()} + return session, weights, payloads + return session, weights, {} @classmethod def from_interactions( cls, interactions: pd.DataFrame, sort_users: bool = False, + interactions_features_schema: tp.Optional[InteractionsFeaturesSchema] = None, ) -> "SequenceDataset": """ Group interactions by user. @@ -73,17 +85,42 @@ def from_interactions( interactions : pd.DataFrame User-item interactions. 
""" + cols_to_agg = [Columns.Item, Columns.Weight] + extra_cols = [] + if interactions_features_schema is not None: + if len(interactions_features_schema.cat_feature_names_w_values) > 0: + # Map categorical features to their inputs for embedding bag + mappings: tp.Dict[str, tp.Dict[str, int]] = {} # TODO: decide feature names + features_mapped = pd.DataFrame() + for inputs, (feature_name, feature_value) in enumerate( + interactions_features_schema.cat_feature_names_w_values + ): + if feature_name not in mappings: + mappings[feature_name] = {} + mappings[feature_name][feature_value] = inputs + for feature_name, feature_mappings in mappings.items(): + features_mapped[feature_name] = interactions[feature_name].map(feature_mappings) + + # Combine all feature values into a single "inputs" column + interactions["context_cat_inputs"] = features_mapped[list(mappings.keys())].values # .tolist() + interactions["context_cat_offsets"] = len(mappings.keys()) + extra_cols.extend(["context_cat_inputs", "context_cat_offsets"]) + sessions = ( interactions.sort_values(Columns.Datetime, kind="stable") - .groupby(Columns.User, sort=sort_users)[[Columns.Item, Columns.Weight]] + .groupby(Columns.User, sort=sort_users)[cols_to_agg + extra_cols] .agg(list) ) - sessions, weights = ( + + session_items, weights = ( sessions[Columns.Item].to_list(), sessions[Columns.Weight].to_list(), ) - return cls(sessions=sessions, weights=weights) + payloads = None + if extra_cols: + payloads = {col: sessions[col].to_list() for col in extra_cols} + return cls(sessions=session_items, weights=weights, payloads=payloads) class TransformerDataPreparatorBase: # pylint: disable=too-many-instance-attributes @@ -232,7 +269,13 @@ def process_dataset_train(self, dataset: Dataset) -> None: # Prepare train dataset # User features are dropped for now because model doesn't support them final_interactions = Interactions.from_raw(interactions, user_id_map, item_id_map, keep_extra_cols=True) - self.train_dataset = Dataset(user_id_map, item_id_map, final_interactions, item_features=item_features) + self.train_dataset = Dataset( + user_id_map, + item_id_map, + final_interactions, + item_features=item_features, + interactions_schema=dataset.interactions_schema, + ) self.item_id_map = self.train_dataset.item_id_map self._init_extra_token_ids() @@ -261,7 +304,10 @@ def get_dataloader_train(self) -> DataLoader: DataLoader Train dataloader. """ - sequence_dataset = SequenceDataset.from_interactions(self.train_dataset.interactions.df) + sequence_dataset = SequenceDataset.from_interactions( + self.train_dataset.interactions.df, + interactions_features_schema=self.train_dataset.interactions_schema, + ) train_dataloader = DataLoader( sequence_dataset, collate_fn=self._collate_fn_train, @@ -283,7 +329,9 @@ def get_dataloader_val(self) -> tp.Optional[DataLoader]: if self.val_interactions is None: return None - sequence_dataset = SequenceDataset.from_interactions(self.val_interactions) + sequence_dataset = SequenceDataset.from_interactions( + self.val_interactions, interactions_features_schema=self.train_dataset.interactions_schema + ) val_dataloader = DataLoader( sequence_dataset, collate_fn=self._collate_fn_val, @@ -306,7 +354,11 @@ def get_dataloader_recommend(self, dataset: Dataset, batch_size: int) -> DataLoa # User ids here are internal user ids in dataset.interactions.df that was prepared for recommendations. 
# Sorting sessions by user ids will ensure that these ids will also be correct indexes in user embeddings matrix # that will be returned by the net. - sequence_dataset = SequenceDataset.from_interactions(interactions=dataset.interactions.df, sort_users=True) + sequence_dataset = SequenceDataset.from_interactions( + interactions=dataset.interactions.df, + sort_users=True, + interactions_features_schema=dataset.interactions_schema, + ) recommend_dataloader = DataLoader( sequence_dataset, batch_size=batch_size, @@ -360,7 +412,12 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset explanation = f"""{n_filtered} target users were considered cold because of missing known items""" warnings.warn(explanation) filtered_interactions = Interactions.from_raw(interactions, rec_user_id_map, self.item_id_map) - filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions) + filtered_dataset = Dataset( + rec_user_id_map, + self.item_id_map, + filtered_interactions, + interactions_schema=dataset.interactions_schema, + ) return filtered_dataset def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: @@ -384,7 +441,12 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: interactions = dataset.get_raw_interactions() interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())] filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map) - filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions) + filtered_dataset = Dataset( + dataset.user_id_map, + self.item_id_map, + filtered_interactions, + interactions_schema=dataset.interactions_schema, + ) return filtered_dataset def _collate_fn_train( diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index b8350f72..1d1354c7 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -36,7 +36,8 @@ TransformerModelConfig, ValMaskCallable, ) -from .data_preparator import InitKwargs, TransformerDataPreparatorBase +from .context_net import CatFeaturesContextNet, ContextNetBase +from .data_preparator import BatchElement, InitKwargs, TransformerDataPreparatorBase from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase from .net_blocks import ( LearnableInversePositionalEncoding, @@ -77,10 +78,11 @@ class SASRecDataPreparator(TransformerDataPreparatorBase): """ train_session_max_len_addition: int = 1 + non_pad_payload_keys: List[str] = [] # ["context_cat_offsets"] def _collate_fn_train( self, - batch: List[Tuple[List[int], List[float]]], + batch: List[BatchElement], ) -> Dict[str, torch.Tensor]: """ Truncate each session from right to keep `session_max_len` items. 
@@ -91,12 +93,25 @@ def _collate_fn_train( x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, self.session_max_len)) yw = np.zeros((batch_size, self.session_max_len)) - for i, (ses, ses_weights) in enumerate(batch): + + payloads_keys = [key for key in batch[0][2].keys() if key not in self.non_pad_payload_keys] + train_payloads = np.zeros((len(payloads_keys), batch_size, self.session_max_len)) + + for i, (ses, ses_weights, payloads) in enumerate(batch): x[i, -len(ses) + 1 :] = ses[:-1] # ses: [session_len] -> x[i]: [session_max_len] y[i, -len(ses) + 1 :] = ses[1:] # ses: [session_len] -> y[i]: [session_max_len] yw[i, -len(ses) + 1 :] = ses_weights[1:] # ses_weights: [session_len] -> yw[i]: [session_max_len] - - batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)} + for j, key in enumerate(payloads_keys): + train_payloads[j, i, -len(ses) + 1 :] = payloads[key][:-1] + batch_dict = { + "x": torch.LongTensor(x), + "y": torch.LongTensor(y), + "yw": torch.FloatTensor(yw), + } + payloads_dict = {key: torch.LongTensor(train_payloads[j]) for j, key in enumerate(payloads_keys)} + for key in self.non_pad_payload_keys: + payloads_dict[key] = torch.LongTensor(payloads[key]) + batch_dict.update(payloads_dict) if self.negative_sampler is not None: batch_dict["negatives"] = self.negative_sampler.get_negatives( batch_dict, lowest_id=self.n_item_extra_tokens, highest_id=self.item_id_map.size @@ -108,18 +123,24 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses yw = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses - for i, (ses, ses_weights) in enumerate(batch): - input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0] - # take only first target for leave-one-strategy - target_idx = [idx for idx, weight in enumerate(ses_weights) if weight != 0][0] + payloads_keys = batch[0][2].keys() + train_payloads = np.zeros((len(payloads_keys), batch_size, self.session_max_len)) + + for i, (ses, ses_weights, payloads) in enumerate(batch): + input_mask = ses_weights == 0 + input_session = ses[input_mask] # ses: [session_len] -> x[i]: [session_max_len] x[i, -len(input_session) :] = input_session[-self.session_max_len :] - y[i, -1:] = ses[target_idx] # y[i]: [1] - yw[i, -1:] = ses_weights[target_idx] # yw[i]: [1] + y[i, -1:] = ses[~input_mask][0] # y[i]: [1] take only first target for leave-one-strategy + yw[i, -1:] = ses_weights[~input_mask][0] # yw[i]: [1] + for j, key in enumerate(payloads_keys): + train_payloads[j, i, -len(input_session) :] = payloads[key][input_mask][-self.session_max_len :] batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)} + payloads_dict = {key: torch.LongTensor(train_payloads[j]) for j, key in enumerate(payloads_keys)} + batch_dict.update(payloads_dict) if self.negative_sampler is not None: batch_dict["negatives"] = self.negative_sampler.get_negatives( batch_dict, lowest_id=self.n_item_extra_tokens, highest_id=self.item_id_map.size, session_len_limit=1 @@ -129,9 +150,16 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]: """Right truncation, left padding to session_max_len""" x = np.zeros((len(batch), self.session_max_len)) - for i, (ses, _) in 
enumerate(batch): + payloads_keys = batch[0][2].keys() + train_payloads = np.zeros((len(payloads_keys), len(batch), self.session_max_len)) + for i, (ses, _, payloads) in enumerate(batch): x[i, -len(ses) :] = ses[-self.session_max_len :] - return {"x": torch.LongTensor(x)} + for j, key in enumerate(payloads_keys): + train_payloads[j, i, -len(ses) :] = payloads[key][-self.session_max_len :] + batch_dict = {"x": torch.LongTensor(x)} + payloads_dict = {key: torch.LongTensor(train_payloads[j]) for j, key in enumerate(payloads_keys)} + batch_dict.update(payloads_dict) + return batch_dict class SASRecTransformerLayer(nn.Module): @@ -444,6 +472,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals negative_sampler_type: tp.Type[TransformerNegativeSamplerBase] = CatalogUniformSampler, similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule, backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, + context_net_type: tp.Type[ContextNetBase] = CatFeaturesContextNet, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None, @@ -460,6 +489,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals negative_sampler_kwargs: tp.Optional[InitKwargs] = None, similarity_module_kwargs: tp.Optional[InitKwargs] = None, backbone_kwargs: tp.Optional[InitKwargs] = None, + context_net_kwargs: tp.Optional[InitKwargs] = None, ): super().__init__( transformer_layers_type=transformer_layers_type, @@ -493,6 +523,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals lightning_module_type=lightning_module_type, negative_sampler_type=negative_sampler_type, backbone_type=backbone_type, + context_net_type=context_net_type, get_val_mask_func=get_val_mask_func, get_trainer_func=get_trainer_func, get_val_mask_func_kwargs=get_val_mask_func_kwargs, @@ -505,4 +536,5 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals negative_sampler_kwargs=negative_sampler_kwargs, similarity_module_kwargs=similarity_module_kwargs, backbone_kwargs=backbone_kwargs, + context_net_kwargs=context_net_kwargs, ) diff --git a/rectools/models/nn/transformers/torch_backbone.py b/rectools/models/nn/transformers/torch_backbone.py index 6a78dc94..4302e094 100644 --- a/rectools/models/nn/transformers/torch_backbone.py +++ b/rectools/models/nn/transformers/torch_backbone.py @@ -34,6 +34,7 @@ def __init__( similarity_module: SimilarityModuleBase, use_causal_attn: bool = True, use_key_padding_mask: bool = False, + context_net: tp.Optional[torch.nn.Module] = None, **kwargs: tp.Any, ) -> None: """ @@ -70,6 +71,7 @@ def __init__( self.use_causal_attn = use_causal_attn self.use_key_padding_mask = use_key_padding_mask self.n_heads = n_heads + self.context_net = context_net def encode_sessions(self, batch: tp.Dict[str, torch.Tensor], item_embs: torch.Tensor) -> torch.Tensor: """ @@ -151,6 +153,7 @@ def __init__( similarity_module: SimilarityModuleBase, use_causal_attn: bool = True, use_key_padding_mask: bool = False, + context_net: tp.Optional[torch.nn.Module] = None, **kwargs: tp.Any, ) -> None: super().__init__( @@ -162,6 +165,7 @@ def __init__( similarity_module=similarity_module, use_causal_attn=use_causal_attn, use_key_padding_mask=use_key_padding_mask, + context_net=context_net, **kwargs, ) @@ -243,6 +247,8 @@ def encode_sessions(self, batch: tp.Dict[str, torch.Tensor], item_embs: torch.Te timeline_mask = (sessions != 0).unsqueeze(-1) # 
[batch_size, session_max_len, 1] seqs = item_embs[sessions] # [batch_size, session_max_len, n_factors] + if self.context_net is not None: + seqs = self.context_net(seqs, batch) seqs = self.pos_encoding_layer(seqs) seqs = self.emb_dropout(seqs) From 88597efc9bfa42c4fd803231ef62274925f50161 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Tue, 3 Jun 2025 19:29:46 +0300 Subject: [PATCH 2/3] cleaned code --- rectools/models/nn/item_net.py | 2 +- rectools/models/nn/transformers/bert4rec.py | 14 ++++----- .../models/nn/transformers/context_net.py | 26 +++++++++-------- .../models/nn/transformers/data_preparator.py | 13 +++++---- rectools/models/nn/transformers/sasrec.py | 29 ++++++++++++------- .../models/nn/transformers/torch_backbone.py | 2 +- tests/models/nn/transformers/test_bert4rec.py | 6 ++-- 7 files changed, 51 insertions(+), 41 deletions(-) diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index cf0be461..65c2f98f 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -486,4 +486,4 @@ def forward(self, items: torch.Tensor) -> torch.Tensor: @property def out_dim(self) -> int: """Return item net constructor output dimension.""" - return self.item_net_blocks[0].out_dim # type: ignore[return-value] + return self.item_net_blocks[0].out_dim diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index 8e31d6ff..108c1db8 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -36,7 +36,7 @@ ValMaskCallable, ) from .constants import MASKING_VALUE, PADDING_VALUE -from .data_preparator import InitKwargs, TransformerDataPreparatorBase +from .data_preparator import BatchElement, InitKwargs, TransformerDataPreparatorBase from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase from .net_blocks import ( LearnableInversePositionalEncoding, @@ -128,7 +128,7 @@ def _mask_session( def _collate_fn_train( self, - batch: List[Tuple[List[int], List[float]]], + batch: List[BatchElement], ) -> Dict[str, torch.Tensor]: """ Mask session elements to receive `x`. 
@@ -141,7 +141,7 @@ def _collate_fn_train( x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, self.session_max_len)) yw = np.zeros((batch_size, self.session_max_len)) - for i, (ses, ses_weights) in enumerate(batch): + for i, (ses, ses_weights, _) in enumerate(batch): masked_session, target = self._mask_session(ses) x[i, -len(ses) :] = masked_session # ses: [session_len] -> x[i]: [session_max_len] y[i, -len(ses) :] = target # ses: [session_len] -> y[i]: [session_max_len] @@ -154,12 +154,12 @@ def _collate_fn_train( ) return batch_dict - def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]: + def _collate_fn_val(self, batch: List[BatchElement]) -> Dict[str, torch.Tensor]: batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, 1)) # until only leave-one-strategy yw = np.zeros((batch_size, 1)) # until only leave-one-strategy - for i, (ses, ses_weights) in enumerate(batch): + for i, (ses, ses_weights, _) in enumerate(batch): input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0] session = input_session.copy() @@ -179,14 +179,14 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st ) return batch_dict - def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]: + def _collate_fn_recommend(self, batch: List[BatchElement]) -> Dict[str, torch.Tensor]: """ Right truncation, left padding to `session_max_len` During inference model will use (`session_max_len` - 1) interactions and one extra "MASK" token will be added for making predictions. """ x = np.zeros((len(batch), self.session_max_len)) - for i, (ses, _) in enumerate(batch): + for i, (ses, _, _) in enumerate(batch): session = ses.copy() session = session + [self.extra_token_ids[MASKING_VALUE]] x[i, -len(ses) - 1 :] = session[-self.session_max_len :] diff --git a/rectools/models/nn/transformers/context_net.py b/rectools/models/nn/transformers/context_net.py index 74ec01d6..86b0159d 100644 --- a/rectools/models/nn/transformers/context_net.py +++ b/rectools/models/nn/transformers/context_net.py @@ -36,11 +36,19 @@ def out_dim(self) -> int: class CatFeaturesContextNet(ContextNetBase): """TODO.""" - def __init__(self, n_factors: int, dropout_rate: float, n_cat_feature_values: int, **kwargs: tp.Any) -> None: + def __init__( + self, + n_factors: int, + dropout_rate: float, + n_cat_feature_values: int, + batch_key: str = "context_cat_inputs", + **kwargs: tp.Any, + ) -> None: super().__init__(n_factors, dropout_rate, **kwargs) print(n_cat_feature_values) self.embedding_bag = nn.EmbeddingBag(num_embeddings=n_cat_feature_values, embedding_dim=n_factors, mode="sum") self.dropout = nn.Dropout(dropout_rate) + self.batch_key = batch_key @classmethod def from_dataset_schema( # TODO: decide about target aware schema @@ -60,18 +68,12 @@ def from_dataset_schema( # TODO: decide about target aware schema def forward(self, seqs: torch.Tensor, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor: """TODO.""" - # TODO: check correctness and remove offsets from batch - b, l, f = seqs.shape - offsets = batch["context_cat_offsets"].view(-1) - offsets = torch.cat([torch.zeros(1, dtype=offsets.dtype, device=offsets.device), offsets]) - offsets = offsets.cumsum(dim=0)[:-1] - - inputs = batch["context_cat_inputs"] - new_inputs = inputs.view(b * l, -1) - context_embs = self.embedding_bag(input=new_inputs) + batch_size, session_max_len, n_factors = seqs.shape + inputs 
= batch[self.batch_key].view(batch_size * session_max_len, -1) + context_embs = self.embedding_bag(input=inputs) context_embs = self.dropout(context_embs) - context_embs = context_embs.view(b, l, f) - return seqs + context_embs + context_embs = context_embs.view(batch_size, session_max_len, n_factors) + return context_embs @property def out_dim(self) -> int: diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index 4c7de9a3..46bd21f8 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -103,8 +103,7 @@ def from_interactions( # Combine all feature values into a single "inputs" column interactions["context_cat_inputs"] = features_mapped[list(mappings.keys())].values # .tolist() - interactions["context_cat_offsets"] = len(mappings.keys()) - extra_cols.extend(["context_cat_inputs", "context_cat_offsets"]) + extra_cols.extend(["context_cat_inputs"]) sessions = ( interactions.sort_values(Columns.Datetime, kind="stable") @@ -289,7 +288,9 @@ def process_dataset_train(self, dataset: Dataset) -> None: val_interactions = interactions[interactions[Columns.User].isin(val_targets[Columns.User].unique())].copy() val_interactions[Columns.Weight] = 0 val_interactions = pd.concat([val_interactions, val_targets], axis=0) - self.val_interactions = Interactions.from_raw(val_interactions, user_id_map, item_id_map).df + self.val_interactions = Interactions.from_raw( + val_interactions, user_id_map, item_id_map, keep_extra_cols=True + ).df def _init_extra_token_ids(self) -> None: extra_token_ids = self.item_id_map.convert_to_internal(self.item_extra_tokens) @@ -451,18 +452,18 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: def _collate_fn_train( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]], + batch: tp.List[BatchElement], ) -> tp.Dict[str, torch.Tensor]: raise NotImplementedError() def _collate_fn_val( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]], + batch: tp.List[BatchElement], ) -> tp.Dict[str, torch.Tensor]: raise NotImplementedError() def _collate_fn_recommend( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]], + batch: tp.List[BatchElement], ) -> tp.Dict[str, torch.Tensor]: raise NotImplementedError() diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index 1d1354c7..48becf37 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -13,7 +13,7 @@ # limitations under the License. 
import typing as tp -from typing import Dict, List, Tuple +from typing import Dict, List import numpy as np import torch @@ -78,7 +78,6 @@ class SASRecDataPreparator(TransformerDataPreparatorBase): """ train_session_max_len_addition: int = 1 - non_pad_payload_keys: List[str] = [] # ["context_cat_offsets"] def _collate_fn_train( self, @@ -94,7 +93,7 @@ def _collate_fn_train( y = np.zeros((batch_size, self.session_max_len)) yw = np.zeros((batch_size, self.session_max_len)) - payloads_keys = [key for key in batch[0][2].keys() if key not in self.non_pad_payload_keys] + payloads_keys = batch[0][2].keys() train_payloads = np.zeros((len(payloads_keys), batch_size, self.session_max_len)) for i, (ses, ses_weights, payloads) in enumerate(batch): @@ -109,8 +108,6 @@ def _collate_fn_train( "yw": torch.FloatTensor(yw), } payloads_dict = {key: torch.LongTensor(train_payloads[j]) for j, key in enumerate(payloads_keys)} - for key in self.non_pad_payload_keys: - payloads_dict[key] = torch.LongTensor(payloads[key]) batch_dict.update(payloads_dict) if self.negative_sampler is not None: batch_dict["negatives"] = self.negative_sampler.get_negatives( @@ -118,7 +115,10 @@ def _collate_fn_train( ) return batch_dict - def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]: + def _collate_fn_val( + self, + batch: List[BatchElement], + ) -> Dict[str, torch.Tensor]: batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses @@ -127,8 +127,10 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st payloads_keys = batch[0][2].keys() train_payloads = np.zeros((len(payloads_keys), batch_size, self.session_max_len)) - for i, (ses, ses_weights, payloads) in enumerate(batch): - input_mask = ses_weights == 0 + for i, (ses_list, ses_weights_list, payloads) in enumerate(batch): + input_mask = np.array(ses_weights_list) == 0 + ses = np.array(ses_list) + ses_weights = np.array(ses_weights_list) input_session = ses[input_mask] # ses: [session_len] -> x[i]: [session_max_len] @@ -136,7 +138,9 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st y[i, -1:] = ses[~input_mask][0] # y[i]: [1] take only first target for leave-one-strategy yw[i, -1:] = ses_weights[~input_mask][0] # yw[i]: [1] for j, key in enumerate(payloads_keys): - train_payloads[j, i, -len(input_session) :] = payloads[key][input_mask][-self.session_max_len :] + train_payloads[j, i, -len(input_session) :] = np.array(payloads[key])[input_mask][ + -self.session_max_len : + ] batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)} payloads_dict = {key: torch.LongTensor(train_payloads[j]) for j, key in enumerate(payloads_keys)} @@ -147,7 +151,10 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st ) return batch_dict - def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]: + def _collate_fn_recommend( + self, + batch: List[BatchElement], + ) -> Dict[str, torch.Tensor]: """Right truncation, left padding to session_max_len""" x = np.zeros((len(batch), self.session_max_len)) payloads_keys = batch[0][2].keys() @@ -156,7 +163,7 @@ def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> D x[i, -len(ses) :] = ses[-self.session_max_len :] for j, key in enumerate(payloads_keys): train_payloads[j, i, -len(ses) :] = payloads[key][-self.session_max_len :] 
- batch_dict = {"x": torch.LongTensor(x)} + batch_dict: Dict[str, torch.Tensor] = {"x": torch.LongTensor(x)} payloads_dict = {key: torch.LongTensor(train_payloads[j]) for j, key in enumerate(payloads_keys)} batch_dict.update(payloads_dict) return batch_dict diff --git a/rectools/models/nn/transformers/torch_backbone.py b/rectools/models/nn/transformers/torch_backbone.py index 4302e094..7e739855 100644 --- a/rectools/models/nn/transformers/torch_backbone.py +++ b/rectools/models/nn/transformers/torch_backbone.py @@ -248,7 +248,7 @@ def encode_sessions(self, batch: tp.Dict[str, torch.Tensor], item_embs: torch.Te seqs = item_embs[sessions] # [batch_size, session_max_len, n_factors] if self.context_net is not None: - seqs = self.context_net(seqs, batch) + seqs += self.context_net(seqs, batch) seqs = self.pos_encoding_layer(seqs) seqs = self.emb_dropout(seqs) diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index 140389aa..f89e82ee 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -34,7 +34,7 @@ TransformerLightningModule, ) from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable -from rectools.models.nn.transformers.data_preparator import InitKwargs +from rectools.models.nn.transformers.data_preparator import BatchElement, InitKwargs from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase from rectools.models.nn.transformers.similarity import DistanceSimilarityModule from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone @@ -640,13 +640,13 @@ def __init__( def _collate_fn_train( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]], + batch: tp.List[BatchElement], ) -> tp.Dict[str, torch.Tensor]: batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, self.session_max_len)) yw = np.zeros((batch_size, self.session_max_len)) - for i, (ses, ses_weights) in enumerate(batch): + for i, (ses, ses_weights, _) in enumerate(batch): y[i, -self.n_last_targets] = ses[-self.n_last_targets] yw[i, -self.n_last_targets] = ses_weights[-self.n_last_targets] x[i, -len(ses) :] = ses From bdb59c1cceb76bec52f8855ca6682ac6a128b1a4 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Tue, 3 Jun 2025 22:47:36 +0300 Subject: [PATCH 3/3] up todo --- rectools/models/nn/transformers/context_net.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rectools/models/nn/transformers/context_net.py b/rectools/models/nn/transformers/context_net.py index 86b0159d..f2ab4b2f 100644 --- a/rectools/models/nn/transformers/context_net.py +++ b/rectools/models/nn/transformers/context_net.py @@ -77,5 +77,5 @@ def forward(self, seqs: torch.Tensor, batch: tp.Dict[str, torch.Tensor]) -> torc @property def out_dim(self) -> int: - """Return categorical item embedding output dimension.""" + """Return output dimension.""" return self.embedding_bag.embedding_dim
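Usage sketches for the context net draft (illustrative only; any name or value below that does not appear in the diffs above is an assumption, not part of the patches).

The first sketch exercises the new per-interaction feature arguments added to `Dataset.construct` in PATCH 1/3. The column name "weekday" and the toy rows are made up; only `interactions_cat_features` and the resulting `interactions` block of the dataset schema come from the diff. `keep_extra_cols=True` is an assumption about how the extra column reaches the data preparator.

    import pandas as pd

    from rectools import Columns
    from rectools.dataset import Dataset

    # Toy interactions with one extra categorical column; "weekday" is a made-up name.
    interactions_df = pd.DataFrame(
        {
            Columns.User: [1, 1, 2],
            Columns.Item: [10, 20, 10],
            Columns.Weight: [1.0, 1.0, 1.0],
            Columns.Datetime: pd.to_datetime(["2025-01-01", "2025-01-02", "2025-01-01"]),
            "weekday": ["mon", "tue", "mon"],
        }
    )

    # New arguments from this patch: declare which interaction columns are categorical features.
    dataset = Dataset.construct(
        interactions_df,
        interactions_cat_features=("weekday",),
        keep_extra_cols=True,  # assumption: extra columns must survive so the data preparator can read them
    )

    # The schema now carries an "interactions" block with the (feature, value) pairs
    # that size the EmbeddingBag in CatFeaturesContextNet.
    print(dataset.get_schema()["interactions"])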
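The second sketch shows the tensor shapes flowing through `CatFeaturesContextNet.forward` as it stands after the "cleaned code" commit: each session position carries the ids of its mapped categorical values, `nn.EmbeddingBag(mode="sum")` pools them into one vector per position, and the backbone adds the result to the item embeddings (`seqs += self.context_net(seqs, batch)`). The numbers are arbitrary toy values assuming a single categorical feature per interaction.

    import torch

    from rectools.models.nn.transformers.context_net import CatFeaturesContextNet

    net = CatFeaturesContextNet(n_factors=8, dropout_rate=0.1, n_cat_feature_values=4)

    batch_size, session_max_len = 2, 5
    seqs = torch.randn(batch_size, session_max_len, 8)  # item embeddings coming from the item net
    batch = {
        # one mapped categorical value id per position; padded positions end up as 0 in the draft collate functions
        "context_cat_inputs": torch.randint(0, 4, (batch_size, session_max_len)),
    }

    context_embs = net(seqs, batch)  # -> [batch_size, session_max_len, n_factors]
    assert context_embs.shape == seqs.shape
    seqs = seqs + context_embs  # what TransformerTorchBackbone.encode_sessions does before positional encoding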
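The third sketch shows the model-level wiring added in `base.py` and forwarded by `SASRecModel`: `context_net_type` plus `context_net_kwargs`. Since this is explicitly a draft patch series, treat this as the intended call pattern rather than guaranteed-to-run code; hyperparameters are placeholders, and `dataset` is the one built in the first sketch.

    from rectools.models import SASRecModel
    from rectools.models.nn.transformers.context_net import CatFeaturesContextNet

    model = SASRecModel(
        n_factors=64,
        session_max_len=32,
        epochs=1,
        context_net_type=CatFeaturesContextNet,  # the default in this draft, shown explicitly
    )
    model.fit(dataset)  # the context net is constructed only if the dataset schema has interaction features
    reco = model.recommend(
        users=dataset.user_id_map.external_ids[:2],
        dataset=dataset,
        k=5,
        filter_viewed=True,
    )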