Source code for elliot.recommender.base_recommender_model

"""
Module description:

"""
import inspect
import logging as pylog
import os

import numpy as np
import random

from elliot.evaluation.evaluator import Evaluator
from elliot.utils.folder import build_model_folder

__version__ = '0.3.1'
__author__ = 'Vito Walter Anelli, Claudio Pomo'
__email__ = 'vitowalter.anelli@poliba.it, claudio.pomo@poliba.it'

from abc import ABC
from abc import abstractmethod
from functools import wraps
from types import SimpleNamespace
from elliot.utils import logging
from elliot.recommender.early_stopping import EarlyStopping


[docs]class BaseRecommenderModel(ABC): def __init__(self, data, config, params, *args, **kwargs): """ This class represents a recommender model. You can load a pretrained model by specifying its checkpoint path and use it for training/testing purposes. Args: data: data loader object params: dictionary with all parameters """ self._data = data self._config = config self._params = params self._negative_sampling = hasattr(data.config, "negative_sampling") self._restore = getattr(self._params.meta, "restore", False) _cutoff_k = getattr(data.config.evaluation, "cutoffs", [data.config.top_k]) _cutoff_k = _cutoff_k if isinstance(_cutoff_k, list) else [_cutoff_k] _first_metric = data.config.evaluation.simple_metrics[0] if data.config.evaluation.simple_metrics else "" _default_validation_k = _cutoff_k[0] self._validation_metric = getattr(self._params.meta, "validation_metric", _first_metric + "@" + str(_default_validation_k)).split("@") if self._validation_metric[0].lower() not in [m.lower() for m in data.config.evaluation.simple_metrics]: raise Exception("Validation metric must be in the list of simple metrics") self._validation_k = int(self._validation_metric[1]) if len(self._validation_metric) > 1 else _cutoff_k[0] if self._validation_k not in _cutoff_k: raise Exception("Validation cutoff must be in general cutoff values") self._validation_metric = self._validation_metric[0] self._save_weights = getattr(self._params.meta, "save_weights", False) self._save_recs = getattr(self._params.meta, "save_recs", False) self._verbose = getattr(self._params.meta, "verbose", None) self._validation_rate = getattr(self._params.meta, "validation_rate", 1) self._optimize_internal_loss = getattr(self._params.meta, "optimize_internal_loss", False) self._epochs = int(getattr(self._params, "epochs", 2)) self._seed = getattr(self._params, "seed", 42) self._early_stopping = EarlyStopping(SimpleNamespace(**getattr(self._params, "early_stopping", {})), self._validation_metric, self._validation_k, _cutoff_k, data.config.evaluation.simple_metrics) self._iteration = 0 if self._epochs < self._validation_rate: raise Exception(f"The first validation epoch ({self._validation_rate}) " f"is later than the overall number of epochs ({self._epochs}).") self._batch_size = getattr(self._params, "batch_size", -1) self.best_metric_value = 0 self._losses = [] self._results = [] self._params_list = []
[docs] def get_base_params_shortcut(self): return "_".join([str(k) + "=" + str(v).replace(".", "$") for k, v in dict({"seed": self._seed, "e": self._epochs, "bs": self._batch_size}).items() ])
[docs] def get_params_shortcut(self): return "_".join([str(p[2])+"="+ str(p[5](getattr(self, p[0])) if p[5] else getattr(self, p[0])).replace(".", "$") for p in self._params_list])
[docs] def autoset_params(self): """ Define Parameters as tuples: (variable_name, public_name, shortcut, default, reading_function, printing_function) Example: self._params_list = [ ("_similarity", "similarity", "sim", "cosine", None, None), ("_user_profile_type", "user_profile", "up", "tfidf", None, None), ("_item_profile_type", "item_profile", "ip", "tfidf", None, None), ("_mlpunits", "mlp_units", "mlpunits", "(1,2,3)", lambda x: list(make_tuple(x)), lambda x: str(x).replace(",", "-")), ] """ self.logger.info("Loading parameters") for variable_name, public_name, shortcut, default, reading_function, _ in self._params_list: if reading_function is None: setattr(self, variable_name, getattr(self._params, public_name, default)) else: setattr(self, variable_name, reading_function(getattr(self._params, public_name, default))) self.logger.info(f"Parameter {public_name} set to {getattr(self, variable_name)}") if not self._params_list: self.logger.info("No parameters defined")
@staticmethod def _batch_remove(original_str: str, char_list): for c in char_list: original_str = original_str.replace(c, "") return original_str
[docs] @abstractmethod def train(self): pass
[docs] @abstractmethod def get_recommendations(self, *args): pass
[docs] @abstractmethod def get_loss(self): pass
[docs] @abstractmethod def get_params(self): pass
[docs] @abstractmethod def get_results(self): pass
[docs]def init_charger(init): @wraps(init) def new_init(self, *args, **kwargs): BaseRecommenderModel.__init__(self, *args, **kwargs) package_name = inspect.getmodule(self).__package__ rec_name = f"external.{self.__class__.__name__}" if "external" in package_name else self.__class__.__name__ self.logger = logging.get_logger_model(rec_name, pylog.CRITICAL if self._config.config_test else pylog.DEBUG) np.random.seed(self._seed) random.seed(self._seed) self._nprandom = np.random self._random = random self._num_items = self._data.num_items self._num_users = self._data.num_users init(self, *args, **kwargs) self.evaluator = Evaluator(self._data, self._params) self._params.name = self.name build_model_folder(self._config.path_output_rec_weight, self.name) self._saving_filepath = os.path.abspath(os.sep.join([self._config.path_output_rec_weight, self.name, f"best-weights-{self.name}"])) return new_init