Off-Policy Learning in Two-stage Recommender Systems

Many real-world recommender systems need to be highly scalable: they match millions of items with billions of users within milliseconds. This scalability requirement has led to the widespread use of two-stage recommender systems, which consist of one or more efficient candidate generation models in the first stage and a more powerful ranking model in the second stage.
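To make the two-stage structure concrete, here is a minimal NumPy sketch. It is an illustration only, not the system studied in this work: a cheap first-stage score picks the top-k candidates, and a second-stage ranking score chooses the final item among those candidates only. The function name `two_stage_recommend` and the toy score vectors are hypothetical.

```python
import numpy as np


def two_stage_recommend(candidate_scores, ranker_scores, k):
    """Hypothetical two-stage pipeline for a single user.

    candidate_scores: (num_items,) scores from the cheap first-stage model.
    ranker_scores:    (num_items,) scores from the second-stage ranker; a real
                      system would only compute these for the candidates.
    Returns (candidate item ids, final recommended item id).
    """
    # Stage 1: keep the k highest-scoring items as candidates.
    candidates = np.argsort(-candidate_scores)[:k]
    # Stage 2: the ranker only chooses among the candidates.
    best = candidates[np.argmax(ranker_scores[candidates])]
    return candidates, int(best)


cand = np.array([0.9, 0.1, 0.8, 0.3, 0.7])
rank = np.array([0.2, 0.9, 0.1, 0.5, 0.6])
candidates, item = two_stage_recommend(cand, rank, k=3)
print(sorted(candidates.tolist()), item)  # [0, 2, 4] 4
```

Item 1 has the highest ranker score but never reaches the ranker, because the first stage filters it out. This dependence of the final recommendation on both stages is the interaction between stages discussed below.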

Logged user feedback, such as clicks or dwell time, is often used to build both the candidate generation and the ranking models of a recommender system. While such data is easy to collect in large amounts, it is inherently biased, because feedback can only be observed on items that were recommended by the previous system. Off-policy correction of such biases has recently attracted increasing interest in recommender-system research. However, most existing work either assumes a single-stage recommender system or applies off-policy correction only to the candidate generation stage, without explicitly considering the interaction between the two stages.
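Inverse propensity scoring (IPS) is a common starting point for such off-policy corrections: each logged reward is reweighted by the ratio between the probability that the new (target) policy picks the logged item and the probability that the logging policy picked it. A minimal NumPy sketch with toy data; the function and variable names are illustrative.

```python
import numpy as np


def ips_estimate(target_probs, logging_probs, rewards):
    """IPS estimate of the expected reward of a target policy.

    target_probs:  pi(a_i | x_i), probability the new policy picks the logged item.
    logging_probs: beta(a_i | x_i), probability the logging policy picked it.
    rewards:       observed feedback r_i (e.g. 1 for a click, 0 otherwise).
    """
    weights = target_probs / logging_probs  # importance weights
    return float(np.mean(weights * rewards))


# Three logged impressions (toy numbers).
pi = np.array([0.5, 0.1, 0.4])
beta = np.array([0.25, 0.5, 0.2])
r = np.array([1.0, 1.0, 0.0])
print(ips_estimate(pi, beta, r))
```

The estimate is unbiased as long as the logging policy gives non-zero probability to every item the target policy can recommend, but its variance grows when individual weights are large, which is one reason the weights are often clipped in practice.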

In this work, we propose a two-stage off-policy policy gradient method and show that ignoring the interaction between the two stages leads to a sub-optimal policy in two-stage recommender systems. The proposed method explicitly takes the ranking model into account when training the candidate generation model, which helps improve the performance of the whole system. We conduct experiments on real-world datasets with a large item space and demonstrate the effectiveness of the proposed method.

Pseudo code

Model structure

Training model

  1. Simulation model: MovieLens-1M is split into training, validation, and test sets in a 3:1:1 ratio.

  2. Behavior policy model and ranking model: 10,000 user-item pairs randomly generated by the simulation model are used to train the behavior policy model. A bandit dataset is then obtained by sampling the top-5 items predicted for each user by the behavior policy model. The ranking model is trained on this dataset, with 2,000 users held out as the validation set and 4,000 users as the test set.

  3. Candidate generation model: the optimizer is AdaGrad with an initial learning rate of 0.05, and the importance-weight clipping parameters of 1-IPS and 2-IPS are set to $c_1 = 10$ and $c_2 = 0.01$. For each training method, we train 20 candidate generation models initialized with different random seeds. Early stopping is applied in both one-stage and two-stage evaluation.
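The clipping parameters bound the importance weights used by the IPS losses. The NumPy sketch below mirrors the weight computation in `loss_ips` in the Losses section: clip the ratios, then divide by their batch mean. Treating $c_1$ as the upper bound and $c_2$ as the lower bound is an assumption based on their values, and `clip_and_normalize` is a hypothetical name.

```python
import numpy as np


def clip_and_normalize(target_probs, logging_probs, c1=10.0, c2=0.01):
    """Clipped, batch-normalized importance weights (sketch).

    Assumes c1 is the upper clipping bound and c2 the lower bound.
    """
    weights = target_probs / logging_probs
    weights = np.clip(weights, c2, c1)  # bound extreme ratios
    return weights / weights.mean()  # normalize so the batch mean is 1


w = clip_and_normalize(np.array([0.9, 0.001, 0.2]), np.array([0.05, 0.5, 0.2]))
print(w)
```

Here the raw ratios 18, 0.002, and 1 are clipped to 10, 0.01, and 1 before normalization, so a single logged item can no longer dominate the gradient.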

Imports

import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

Model

NUM_ITEMS = 3883
NUM_YEARS = 81
NUM_GENRES = 18
NUM_USERS = 6040
NUM_OCCUPS = 21
NUM_AGES = 7
NUM_ZIPS = 3439
class ItemRep(nn.Module):
    """Item representation layer."""

    def __init__(self, item_emb_size=10, year_emb_size=5, genre_hidden=5):
        super(ItemRep, self).__init__()
        self.item_embedding = nn.Embedding(
            NUM_ITEMS + 1, item_emb_size, padding_idx=0)
        self.year_embedding = nn.Embedding(NUM_YEARS, year_emb_size)
        self.genre_linear = nn.Linear(NUM_GENRES, genre_hidden)
        self.rep_dim = item_emb_size + year_emb_size + genre_hidden

    def forward(self, categorical_feats, real_feats):
        out = torch.cat(
            [
                self.item_embedding(categorical_feats[:, 0]),
                self.year_embedding(categorical_feats[:, 1]),
                self.genre_linear(real_feats)
            ],
            dim=1)
        return out
class UserRep(nn.Module):
    """User representation layer."""

    def __init__(self, user_emb_size=10, feature_emb_size=5):
        super(UserRep, self).__init__()
        self.user_embedding = nn.Embedding(
            NUM_USERS + 1, user_emb_size, padding_idx=0)
        self.gender_embedding = nn.Embedding(2, feature_emb_size)
        self.age_embedding = nn.Embedding(NUM_AGES, feature_emb_size)
        self.occup_embedding = nn.Embedding(NUM_OCCUPS, feature_emb_size)
        self.zip_embedding = nn.Embedding(NUM_ZIPS, feature_emb_size)
        self.rep_dim = user_emb_size + feature_emb_size * 4

    def forward(self, categorical_feats, real_feats=None):
        reps = [
            self.user_embedding(categorical_feats[:, 0]),
            self.gender_embedding(categorical_feats[:, 1]),
            self.age_embedding(categorical_feats[:, 2]),
            self.occup_embedding(categorical_feats[:, 3]),
            self.zip_embedding(categorical_feats[:, 4])
        ]
        out = torch.cat(reps, dim=1)
        return out
class ImpressionSimulator(nn.Module):
    """Simulator model that predicts the outcome of impression."""

    def __init__(self, hidden=100, use_impression_feats=False):
        super(ImpressionSimulator, self).__init__()
        self.user_rep = UserRep()
        self.item_rep = ItemRep()
        self.use_impression_feats = use_impression_feats
        input_dim = self.user_rep.rep_dim + self.item_rep.rep_dim
        if use_impression_feats:
            input_dim += 1
        self.linear = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(), nn.Linear(hidden, 50), nn.ReLU(), nn.Linear(50, 1))

    def forward(self, user_feats, item_feats, impression_feats=None):
        users = self.user_rep(**user_feats)
        items = self.item_rep(**item_feats)
        inputs = torch.cat([users, items], dim=1)
        if self.use_impression_feats:
            inputs = torch.cat([inputs, impression_feats["real_feats"]], dim=1)
        return self.linear(inputs).squeeze()
class Nominator(nn.Module):
    """Two tower nominator model."""

    def __init__(self):
        super(Nominator, self).__init__()
        self.item_rep = ItemRep()
        self.user_rep = UserRep()
        self.linear = nn.Linear(self.user_rep.rep_dim, self.item_rep.rep_dim)
        self.binary = True

    def forward(self, user_feats, item_feats):
        users = self.linear(F.relu(self.user_rep(**user_feats)))
        users = torch.unsqueeze(users, 2)  # (b, h) -> (b, h, 1)
        items = self.item_rep(**item_feats)
        if self.binary:
            items = torch.unsqueeze(items, 1)  # (b, h) -> (b, 1, h)
        else:
            items = torch.unsqueeze(items, 0).expand(users.size(0), -1,
                                                     -1)  # (c, h) -> (b, c, h)
        logits = torch.bmm(items, users).squeeze()
        return logits

    def set_binary(self, binary=True):
        self.binary = binary
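With the default sizes, `ItemRep` produces 10 + 5 + 5 = 20-dimensional item vectors and `UserRep` produces 10 + 4 × 5 = 30-dimensional user vectors, which the nominator's `linear` layer maps into the 20-dimensional item space. Both modes of `Nominator.forward` then reduce to batched dot products. The NumPy sketch below reproduces only that shape logic, using `np.matmul` in place of `torch.bmm` (both batch over the leading dimension); the random vectors are stand-ins for learned representations.

```python
import numpy as np

rng = np.random.default_rng(0)
b, c, h = 4, 6, 20  # batch size, number of items, representation size
users = rng.normal(size=(b, h))         # user vectors after the linear layer
paired_items = rng.normal(size=(b, h))  # one item per user (binary mode)
all_items = rng.normal(size=(c, h))     # shared item set (non-binary mode)

# Binary mode: (b, 1, h) @ (b, h, 1) -> (b, 1, 1) -> one logit per pair.
binary_logits = np.matmul(paired_items[:, None, :], users[:, :, None]).squeeze()

# Non-binary mode: expand items to (b, c, h); (b, c, h) @ (b, h, 1) -> (b, c).
expanded = np.broadcast_to(all_items[None, :, :], (b, c, h))
all_logits = np.matmul(expanded, users[:, :, None]).squeeze()

print(binary_logits.shape, all_logits.shape)  # (4,) (4, 6)
```

As with `.squeeze()` in the original `forward`, a batch of size 1 would also lose its batch dimension.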
class Ranker(nn.Module):
    """Ranker model."""

    def __init__(self):
        super(Ranker, self).__init__()
        self.item_rep = ItemRep()
        self.user_rep = UserRep()
        self.linear = nn.Linear(self.user_rep.rep_dim + 1,
                                self.item_rep.rep_dim)
        self.binary = True

    def forward(self, user_feats, item_feats, impression_feats):
        users = self.user_rep(**user_feats)
        context_users = torch.cat(
            [users, impression_feats["real_feats"]], dim=1)
        context_users = self.linear(context_users)
        context_users = torch.unsqueeze(context_users,
                                        2)  # (b, h) -> (b, h, 1)
        items = self.item_rep(**item_feats)
        if self.binary:
            items = torch.unsqueeze(items, 1)  # (b, h) -> (b, 1, h)
        else:
            items = torch.unsqueeze(items, 0).expand(
                users.size(0), -1, -1)  # (c, h) -> (b, c, h), c=#items
        logits = torch.bmm(items, context_users).squeeze()
        return logits

    def set_binary(self, binary=True):
        self.binary = binary

Download data

!wget -q --show-progress https://files.grouplens.org/datasets/movielens/ml-1m.zip
!unzip ml-1m.zip
ml-1m.zip           100%[===================>]   5.64M  32.7MB/s    in 0.2s    
Archive:  ml-1m.zip
   creating: ml-1m/
  inflating: ml-1m/movies.dat        
  inflating: ml-1m/ratings.dat       
  inflating: ml-1m/README            
  inflating: ml-1m/users.dat         

Dataset

Create a PyTorch dataset class for ML-1M and convert the ratings into binary labels (positive if rating > 3).

class MovieLensDataset(Dataset):
    def __init__(self, filepath, device="cuda:0"):
        self.device = device
        ratings, users, items = self.load_data(filepath)

        self.user_feats = {}
        self.item_feats = {}
        self.impression_feats = {}

        self.user_feats["categorical_feats"] = torch.LongTensor(
            users.values).to(device)
        self.item_feats["categorical_feats"] = torch.LongTensor(
            items.values[:, :2]).to(device)
        self.item_feats["real_feats"] = torch.FloatTensor(
            items.values[:, 2:]).to(device)
        self.impression_feats["user_ids"] = torch.LongTensor(
            ratings.values[:, 0]).to(device)
        self.impression_feats["item_ids"] = torch.LongTensor(
            ratings.values[:, 1]).to(device)
        self.impression_feats["real_feats"] = torch.FloatTensor(
            ratings.values[:, 3]).view(-1, 1).to(device)
        self.impression_feats["labels"] = torch.FloatTensor(
            ratings.values[:, 2]).to(device)

    def __len__(self):
        return len(self.impression_feats["user_ids"])

    def __getitem__(self, idx):
        labels = self.impression_feats["labels"][idx]
        feats = {}
        feats["impression_feats"] = {}
        feats["impression_feats"]["real_feats"] = self.impression_feats[
            "real_feats"][idx]
        user_id = self.impression_feats["user_ids"][idx]
        item_id = self.impression_feats["item_ids"][idx]
        feats["user_feats"] = {
            key: value[user_id - 1]
            for key, value in self.user_feats.items()
        }
        feats["item_feats"] = {
            key: value[item_id - 1]
            for key, value in self.item_feats.items()
        }
        return feats, labels

    def load_data(self, filepath):
        names = "UserID::MovieID::Rating::Timestamp".split("::")
        ratings = pd.read_csv(
            os.path.join(filepath, "ratings.dat"),
            sep="::",
            names=names,
            engine="python")
        ratings["Rating"] = (ratings["Rating"] > 3).astype(int)
        ratings["Timestamp"] = (
            ratings["Timestamp"] - ratings["Timestamp"].min()
        ) / float(ratings["Timestamp"].max() - ratings["Timestamp"].min())

        names = "UserID::Gender::Age::Occupation::Zip-code".split("::")
        users = pd.read_csv(
            os.path.join(filepath, "users.dat"),
            sep="::",
            names=names,
            engine="python")
        # Factorize the categorical user columns. Assigning whole columns (rather
        # than via `.iloc`) makes them integer-typed in recent pandas versions too.
        for col in users.columns[1:]:
            users[col] = pd.factorize(users[col])[0]

        names = "MovieID::Title::Genres".split("::")
        Genres = [
            "Action", "Adventure", "Animation", "Children's", "Comedy",
            "Crime", "Documentary", "Drama", "Fantasy", "Film-Noir", "Horror",
            "Musical", "Mystery", "Romance", "Sci-Fi", "Thriller", "War",
            "Western"
        ]
        # Read movies.dat as Latin-1, since some titles are not valid UTF-8.
        movies = pd.read_csv(
            os.path.join(filepath, "movies.dat"),
            sep="::",
            names=names,
            engine="python",
            encoding="latin-1")
        # Titles end with "(YYYY)"; the release year is a categorical feature.
        movies["Year"] = pd.factorize(
            movies["Title"].apply(lambda x: x[-5:-1]))[0]
        # Multi-hot genre vector, normalized so that each movie's genres sum to 1.
        for genre in Genres:
            movies[genre] = movies["Genres"].apply(lambda x: genre in x)
        movies[Genres] = movies[Genres].astype(float)
        movies[Genres] = movies[Genres].div(movies[Genres].sum(axis=1), axis=0)
        movies = movies.loc[:, ["MovieID", "Year"] + Genres].copy()

        movie_id_map = {}
        for i in range(movies.shape[0]):
            movie_id_map[movies.loc[i, "MovieID"]] = i + 1

        movies["MovieID"] = movies["MovieID"].apply(lambda x: movie_id_map[x])
        ratings["MovieID"] = ratings["MovieID"].apply(
            lambda x: movie_id_map[x])

        self.NUM_ITEMS = len(movies.MovieID.unique())
        self.NUM_YEARS = len(movies.Year.unique())
        self.NUM_GENRES = movies.shape[1] - 2

        self.NUM_USERS = len(users.UserID.unique())
        self.NUM_OCCUPS = len(users.Occupation.unique())
        self.NUM_AGES = len(users.Age.unique())
        self.NUM_ZIPS = len(users["Zip-code"].unique())

        return ratings, users, movies
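The feature engineering in `load_data` can be checked on a few toy rows. This NumPy sketch, with toy values and hypothetical helper names, reproduces three of its steps: binarizing ratings with the > 3 threshold, min-max scaling timestamps to [0, 1], and turning a pipe-separated genre string into a multi-hot vector normalized to sum to 1.

```python
import numpy as np

GENRES = ["Action", "Comedy", "Drama"]  # toy subset of the 18 ML-1M genres


def binarize(ratings):
    # Positive label if rating > 3, as in load_data.
    return (np.asarray(ratings) > 3).astype(int)


def minmax(timestamps):
    # Scale to [0, 1] using the global min and max.
    t = np.asarray(timestamps, dtype=float)
    return (t - t.min()) / (t.max() - t.min())


def genre_vector(genres):
    # Substring membership test, mirroring `genre in x` in load_data.
    v = np.array([float(g in genres) for g in GENRES])
    return v / v.sum()


print(binarize([5, 3, 4, 1]))        # [1 0 1 0]
print(minmax([100, 150, 300]))
print(genre_vector("Action|Drama"))
```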
class SyntheticMovieLensDataset(Dataset):
    def __init__(self, filepath, simulator_path, synthetic_data_path, cut=0.764506,
                 device="cuda:0"):
        self.device = device
        self.cut = cut
        self.simulator = None

        ratings, users, items = self.load_data(filepath)

        self.user_feats = {}
        self.item_feats = {}
        self.impression_feats = {}

        self.user_feats["categorical_feats"] = torch.LongTensor(users.values)
        self.item_feats["categorical_feats"] = torch.LongTensor(
            items.values[:, :2])
        self.item_feats["real_feats"] = torch.FloatTensor(items.values[:, 2:])

        if os.path.exists(synthetic_data_path):
            self.impression_feats = torch.load(synthetic_data_path)
            self.impression_feats["labels"] = (
                self.impression_feats["label_probs"] >= cut).to(
                    dtype=torch.float32)
            print("loaded full_impression_feats.pt")
        else:
            print("generating impression_feats")
            self.simulator = ImpressionSimulator(use_impression_feats=True)
            self.simulator.load_state_dict(torch.load(simulator_path))
            self.simulator = self.simulator.to(device)

            impressions = self.get_full_impressions(ratings)

            self.impression_feats["user_ids"] = torch.LongTensor(
                impressions[:, 0])
            self.impression_feats["item_ids"] = torch.LongTensor(
                impressions[:, 1])
            self.impression_feats["real_feats"] = torch.FloatTensor(
                impressions[:, 2]).view(-1, 1)
            self.impression_feats["labels"] = torch.zeros_like(
                self.impression_feats["real_feats"])

            self.impression_feats["label_probs"] = self.generate_labels()
            self.impression_feats["labels"] = (
                self.impression_feats["label_probs"] >= cut).to(
                    dtype=torch.float32)

            torch.save(self.impression_feats, synthetic_data_path)
            print("saved impression_feats")

    def __len__(self):
        return len(self.impression_feats["user_ids"])

    def __getitem__(self, idx):
        labels = self.impression_feats["labels"][idx]
        feats = {}
        feats["impression_feats"] = {}
        feats["impression_feats"]["real_feats"] = self.impression_feats[
            "real_feats"][idx]
        user_id = self.impression_feats["user_ids"][idx]
        item_id = self.impression_feats["item_ids"][idx]
        feats["user_feats"] = {
            key: value[user_id - 1]
            for key, value in self.user_feats.items()
        }
        feats["item_feats"] = {
            key: value[item_id - 1]
            for key, value in self.item_feats.items()
        }
        return feats, labels

    def load_data(self, filepath):
        names = "UserID::MovieID::Rating::Timestamp".split("::")
        ratings = pd.read_csv(
            os.path.join(filepath, "ratings.dat"),
            sep="::",
            names=names,
            engine="python")
        ratings["Rating"] = (ratings["Rating"] > 3).astype(int)
        ratings["Timestamp"] = (
            ratings["Timestamp"] - ratings["Timestamp"].min()
        ) / float(ratings["Timestamp"].max() - ratings["Timestamp"].min())

        names = "UserID::Gender::Age::Occupation::Zip-code".split("::")
        users = pd.read_csv(
            os.path.join(filepath, "users.dat"),
            sep="::",
            names=names,
            engine="python")
        # Factorize the categorical user columns. Assigning whole columns (rather
        # than via `.iloc`) makes them integer-typed in recent pandas versions too.
        for col in users.columns[1:]:
            users[col] = pd.factorize(users[col])[0]

        names = "MovieID::Title::Genres".split("::")
        Genres = [
            "Action", "Adventure", "Animation", "Children's", "Comedy",
            "Crime", "Documentary", "Drama", "Fantasy", "Film-Noir", "Horror",
            "Musical", "Mystery", "Romance", "Sci-Fi", "Thriller", "War",
            "Western"
        ]
        # Read movies.dat as Latin-1, since some titles are not valid UTF-8.
        movies = pd.read_csv(
            os.path.join(filepath, "movies.dat"),
            sep="::",
            names=names,
            engine="python",
            encoding="latin-1")
        # Titles end with "(YYYY)"; the release year is a categorical feature.
        movies["Year"] = pd.factorize(
            movies["Title"].apply(lambda x: x[-5:-1]))[0]
        # Multi-hot genre vector, normalized so that each movie's genres sum to 1.
        for genre in Genres:
            movies[genre] = movies["Genres"].apply(lambda x: genre in x)
        movies[Genres] = movies[Genres].astype(float)
        movies[Genres] = movies[Genres].div(movies[Genres].sum(axis=1), axis=0)
        movies = movies.loc[:, ["MovieID", "Year"] + Genres].copy()

        movie_id_map = {}
        for i in range(movies.shape[0]):
            movie_id_map[movies.loc[i, "MovieID"]] = i + 1

        movies["MovieID"] = movies["MovieID"].apply(lambda x: movie_id_map[x])
        ratings["MovieID"] = ratings["MovieID"].apply(
            lambda x: movie_id_map[x])

        self.NUM_ITEMS = len(movies.MovieID.unique())
        self.NUM_YEARS = len(movies.Year.unique())
        self.NUM_GENRES = movies.shape[1] - 2

        self.NUM_USERS = len(users.UserID.unique())
        self.NUM_OCCUPS = len(users.Occupation.unique())
        self.NUM_AGES = len(users.Age.unique())
        self.NUM_ZIPS = len(users["Zip-code"].unique())

        return ratings, users, movies

    def get_full_impressions(self, ratings):
        """Gets NUM_USERS x NUM_ITEMS impression features by iterating the user and item ids.
        The impression-level feature, i.e. the timestamp, is sampled from a normal distribution with
        mean and std as the empirical mean and std of each user's recorded timestamps in the real data.
        """
        # Each user's recorded timestamps, in their original row order.
        timestamps = {
            u_id: group.to_numpy()
            for u_id, group in ratings.groupby("UserID")["Timestamp"]
        }
        rs = np.random.RandomState(0)
        t_samples = []
        for i in range(self.NUM_USERS):
            u_id = i + 1
            t_samples.append(
                rs.normal(
                    loc=np.mean(timestamps[u_id]),
                    scale=np.std(timestamps[u_id]),
                    size=(self.NUM_ITEMS, )))
        t_samples = np.array(t_samples)

        # User-major layout: row u * NUM_ITEMS + i holds (user u + 1, item i + 1,
        # sampled timestamp). Vectorized to avoid a Python list with
        # NUM_USERS * NUM_ITEMS rows.
        user_ids = np.repeat(np.arange(1, self.NUM_USERS + 1), self.NUM_ITEMS)
        item_ids = np.tile(np.arange(1, self.NUM_ITEMS + 1), self.NUM_USERS)
        impressions = np.stack(
            [user_ids, item_ids, t_samples.reshape(-1)], axis=1)
        return impressions

    def to_device(self, data):
        if isinstance(data, torch.Tensor):
            return data.to(self.device)
        if isinstance(data, dict):
            transformed_data = {}
            for key in data:
                transformed_data[key] = self.to_device(data[key])
        elif isinstance(data, list):
            transformed_data = []
            for x in data:
                transformed_data.append(self.to_device(x))
        else:
            raise NotImplementedError(
                "Type {} not supported.".format(type(data)))
        return transformed_data

    def generate_labels(self, batch_size=500):
        """Runs the simulator on every user-item pair and returns the label probabilities."""
        num_impressions = len(self.impression_feats["labels"])
        preds = []
        with torch.no_grad():
            self.simulator.eval()
            # Fixed-size batches; the last batch may be smaller.
            for start in tqdm(range(0, num_impressions, batch_size)):
                idx = list(range(start, min(start + batch_size, num_impressions)))
                feats, _ = self.__getitem__(idx)
                feats = self.to_device(feats)
                outputs = torch.sigmoid(self.simulator(**feats))
                # reshape(-1) also handles a final batch of size 1.
                preds.append(outputs.reshape(-1).cpu())
        return torch.cat(preds)
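The synthetic dataset stores all NUM_USERS × NUM_ITEMS impressions in user-major order, so the impression for 0-based user `u` and 0-based item `i` sits at flat index `u * NUM_ITEMS + i`. The `Evaluator` below relies on this when it builds `impression_ids`. Labels are obtained by thresholding the simulator's probabilities at `cut`. A small NumPy sketch of both conventions, with toy sizes and made-up probabilities (`flat_index` is a hypothetical helper):

```python
import numpy as np

num_users, num_items, cut = 3, 4, 0.764506

# User-major layout with 1-based ids, as produced by get_full_impressions.
user_ids = np.repeat(np.arange(1, num_users + 1), num_items)
item_ids = np.tile(np.arange(1, num_items + 1), num_users)


def flat_index(u, i):
    """Flat impression index of 0-based user u and 0-based item i."""
    return u * num_items + i


idx = flat_index(2, 1)
print(idx, user_ids[idx], item_ids[idx])  # 9 3 2

# Binary labels: simulator probability >= cut (toy probabilities).
label_probs = np.array([0.9, 0.5, 0.764506, 0.1])
labels = (label_probs >= cut).astype(np.float32)
print(labels)
```

The inverse mapping is `divmod(idx, num_items)`, which returns the 0-based `(user, item)` pair.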

Metrics

class BaseMetric(object):
    def __init__(self, rel_threshold, k):
        self.rel_threshold = rel_threshold
        if np.isscalar(k):
            k = np.array([k])
        self.k = k

    def __len__(self):
        return len(self.k)

    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def _compute(self, *args, **kwargs):
        raise NotImplementedError
class PrecisionRecall(BaseMetric):
    def __init__(self, rel_threshold=0, k=10):
        super(PrecisionRecall, self).__init__(rel_threshold, k)

    def __len__(self):
        return 2 * len(self.k)

    def __str__(self):
        str_precision = [('Precision@%1.f' % x) for x in self.k]
        str_recall = [('Recall@%1.f' % x) for x in self.k]
        return (','.join(str_precision)) + ',' + (','.join(str_recall))

    def __call__(self, targets, predictions):
        precision, recall = zip(
            *[self._compute(targets, predictions, x) for x in self.k])
        result = np.concatenate((precision, recall), axis=0)
        return result

    def _compute(self, targets, predictions, k):
        predictions = predictions[:k]
        num_hit = len(set(predictions).intersection(set(targets)))

        return float(num_hit) / len(predictions), float(num_hit) / len(targets)
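A worked example of `PrecisionRecall._compute`: with targets {2, 5, 7} and the ranking [5, 1, 2, 9], the top-2 list contains one target (5), so precision@2 = 1/2 and recall@2 = 1/3. The standalone function below is a hypothetical re-implementation of the same formula, so it runs without the class.

```python
def precision_recall_at_k(targets, predictions, k):
    """Same formula as PrecisionRecall._compute."""
    top_k = predictions[:k]
    num_hit = len(set(top_k) & set(targets))
    return num_hit / len(top_k), num_hit / len(targets)


print(precision_recall_at_k([2, 5, 7], [5, 1, 2, 9], k=2))  # (0.5, 0.3333333333333333)
```

As in the class, precision divides by the length of the truncated list, which is shorter than k when fewer than k predictions are given.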
class MeanAP(BaseMetric):
    def __init__(self, rel_threshold=0, k=np.inf):
        super(MeanAP, self).__init__(rel_threshold, k)

    def __call__(self, targets, predictions):
        result = [self._compute(targets, predictions, x) for x in self.k]
        return np.array(result)

    def __str__(self):
        return ','.join([('MeanAP@%1.f' % x) for x in self.k])

    def _compute(self, targets, predictions, k):
        if len(predictions) > k:
            predictions = predictions[:k]

        score = 0.0
        num_hits = 0.0

        for i, p in enumerate(predictions):
            if p in targets and p not in predictions[:i]:
                num_hits += 1.0
                score += num_hits / (i + 1.0)

        if not list(targets):
            return 0.0

        return score / min(len(targets), k)
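In `MeanAP._compute`, each hit at 1-based rank r contributes (hits so far) / r, and the sum is divided by min(|targets|, k). For targets {2, 5, 7} and ranking [5, 1, 2, 9], hits occur at ranks 1 and 3, so AP = (1/1 + 2/3) / 3 = 5/9. The standalone sketch below (a hypothetical free-function version of the same computation) confirms this.

```python
def average_precision(targets, predictions, k=float("inf")):
    """Same computation as MeanAP._compute, written as a free function."""
    if len(predictions) > k:
        predictions = predictions[:int(k)]
    score, num_hits = 0.0, 0
    for i, p in enumerate(predictions):
        # Count each target once, even if it is repeated in the ranking.
        if p in targets and p not in predictions[:i]:
            num_hits += 1
            score += num_hits / (i + 1)
    if not list(targets):
        return 0.0
    return score / min(len(targets), k)


print(average_precision([2, 5, 7], [5, 1, 2, 9]))
```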
class NormalizedDCG(BaseMetric):
    def __init__(self, rel_threshold=0, k=10):
        super(NormalizedDCG, self).__init__(rel_threshold, k)

    def __call__(self, targets, predictions):
        result = [self._compute(targets, predictions, x) for x in self.k]
        return np.array(result)

    def __str__(self):
        return ','.join([('NDCG@%1.f' % x) for x in self.k])

    def _compute(self, targets, predictions, k):
        k = min(len(targets), k)

        if len(predictions) > k:
            predictions = predictions[:k]

        # compute idcg
        idcg = np.sum(1 / np.log2(np.arange(2, k + 2)))
        dcg = 0.0
        for i, p in enumerate(predictions):
            if p in targets:
                dcg += 1 / np.log2(i + 2)
        ndcg = dcg / idcg

        return ndcg


all_metrics = [PrecisionRecall(k=[1, 5, 10]), NormalizedDCG(k=[5, 10, 20])]
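Note that `NormalizedDCG._compute` first sets k = min(|targets|, k) and then truncates the predictions to that k, so a user with few targets is scored only on their top |targets| predictions. A common binary NDCG@k definition instead computes DCG over the full top-k list and caps only the ideal DCG at min(|targets|, k). The sketch below (hypothetical function names) shows the difference: with targets {3, 8} and ranking [1, 3, 8], the code's NDCG@5 only looks at [1, 3].

```python
import numpy as np


def ndcg_as_in_code(targets, predictions, k):
    """Mirrors NormalizedDCG._compute: k is first capped at the number of targets."""
    k = min(len(targets), k)
    predictions = predictions[:k]
    idcg = np.sum(1 / np.log2(np.arange(2, k + 2)))
    dcg = sum(1 / np.log2(i + 2) for i, p in enumerate(predictions) if p in targets)
    return dcg / idcg


def ndcg_standard(targets, predictions, k):
    """Binary NDCG@k: DCG over the top-k list, ideal DCG over min(|targets|, k) hits."""
    ideal_hits = min(len(targets), k)
    idcg = np.sum(1 / np.log2(np.arange(2, ideal_hits + 2)))
    dcg = sum(1 / np.log2(i + 2) for i, p in enumerate(predictions[:k]) if p in targets)
    return dcg / idcg


print(ndcg_as_in_code([3, 8], [1, 3, 8], k=5))
print(ndcg_standard([3, 8], [1, 3, 8], k=5))
```

The hit on item 8 at rank 3 counts toward the second value but not the first.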
class Evaluator(object):
    """Evaluator for both one-stage and two-stage evaluations."""

    def __init__(self, u, a, simulator, syn):
        self.u = u
        self.a = a
        self.simulator = simulator
        self.syn = syn

        self.target_rankings = self.get_target_rankings()
        self.metrics = all_metrics

    def get_target_rankings(self):
        target_rankings = []
        with torch.no_grad():
            self.simulator.eval()
            for i in range(NUM_USERS):
                impression_ids = range(i * NUM_ITEMS, (i + 1) * NUM_ITEMS)
                feats, _ = self.syn[impression_ids]
                feats["impression_feats"]["real_feats"] = torch.mean(
                    feats["impression_feats"]["real_feats"],
                    dim=0,
                    keepdim=True).repeat([NUM_ITEMS, 1])
                feats = self.syn.to_device(feats)
                outputs = torch.sigmoid(self.simulator(**feats))
                user_target_ranking = (outputs >
                                       self.syn.cut).nonzero().view(-1)
                target_rankings.append(user_target_ranking.cpu().numpy())
        return target_rankings

    def one_stage_ranking_eval(self, logits, user_list):
        for i, user in enumerate(user_list):
            user_rated_items = self.a[self.u == user]
            logits[i, user_rated_items] = -np.inf
        sort_idx = torch.argsort(logits, dim=1, descending=True).cpu().numpy()
        # Init evaluation results.
        total_metrics_len = 0
        for metric in self.metrics:
            total_metrics_len += len(metric)

        total_val_metrics = np.zeros(
            [len(user_list), total_metrics_len], dtype=np.float32)
        valid_rows = []
        for i, user in enumerate(user_list):
            pred_ranking = sort_idx[i].tolist()
            target_ranking = self.target_rankings[user]
            if len(target_ranking) <= 0:
                continue
            metric_results = list()
            for j, metric in enumerate(self.metrics):
                result = metric(
                    targets=target_ranking, predictions=pred_ranking)
                metric_results.append(result)
            total_val_metrics[i, :] = np.concatenate(metric_results)
            valid_rows.append(i)
        # Average evaluation results by user.
        total_val_metrics = total_val_metrics[valid_rows]
        avg_val_metrics = (total_val_metrics.mean(axis=0)).tolist()
        # Summary evaluation results into a dict.
        ind, result = 0, OrderedDict()
        for metric in self.metrics:
            values = avg_val_metrics[ind:ind + len(metric)]
            if len(values) <= 1:
                result[str(metric)] = values
            else:
                for name, value in zip(str(metric).split(','), values):
                    result[name] = value
            ind += len(metric)
        return result

    def two_stage_ranking_eval(self, logits, ranker, user_list, k=30):
        sort_idx = torch.argsort(logits, dim=1, descending=True).cpu().numpy()
        topk_item_ids = []
        for i, user in enumerate(user_list):
            topk_item_ids.append([])
            for j in sort_idx[i]:
                if j not in self.a[self.u == user]:
                    topk_item_ids[-1].append(j)
                if len(topk_item_ids[-1]) == k:
                    break
        time_feats = self.syn.to_device(
            torch.mean(
                self.syn.impression_feats["real_feats"].view(
                    NUM_USERS, NUM_ITEMS),
                dim=1).view(-1, 1))
        # Init evaluation results.
        total_metrics_len = 0
        for metric in self.metrics:
            total_metrics_len += len(metric)

        total_val_metrics = np.zeros(
            [len(user_list), total_metrics_len], dtype=np.float32)
        valid_rows = []
        for i, user in enumerate(user_list):
            user_feats = {
                key: value[user].view(1, -1)
                for key, value in self.syn.user_feats.items()
            }
            item_feats = {
                key: value[topk_item_ids[i]]
                for key, value in self.syn.item_feats.items()
            }

            user_feats = self.syn.to_device(user_feats)
            item_feats = self.syn.to_device(item_feats)
            impression_feats = {"real_feats": time_feats[user].view(1, -1)}
            ranker_logits = ranker(user_feats, item_feats,
                                   impression_feats).view(1, -1)
            _, pred = ranker_logits.topk(k=k)
            pred = pred[0].cpu().numpy()
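            # `pred` indexes this user's candidate list (already-rated items were
            # skipped), so map it back through topk_item_ids, not sort_idx.
            pred_ranking = np.asarray(topk_item_ids[i])[pred].tolist()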
            target_ranking = self.target_rankings[user]
            if len(target_ranking) <= 0:
                continue
            metric_results = list()
            for j, metric in enumerate(self.metrics):
                result = metric(
                    targets=target_ranking, predictions=pred_ranking)
                metric_results.append(result)
            total_val_metrics[i, :] = np.concatenate(metric_results)
            valid_rows.append(i)
        # Average evaluation results by user.
        total_val_metrics = total_val_metrics[valid_rows]
        avg_val_metrics = (total_val_metrics.mean(axis=0)).tolist()
        # Summary evaluation results into a dict.
        ind, result = 0, OrderedDict()
        for metric in self.metrics:
            values = avg_val_metrics[ind:ind + len(metric)]
            if len(values) <= 1:
                result[str(metric)] = values
            else:
                for name, value in zip(str(metric).split(','), values):
                    result[name] = value
            ind += len(metric)
        return result

    def one_stage_eval(self, logits):
        sort_idx = torch.argsort(logits, dim=1, descending=True).cpu().numpy()
        impression_ids = []
        for i in range(NUM_USERS):
            for j in sort_idx[i]:
                if j not in self.a[self.u == i]:
                    break
            impression_ids.append(i * NUM_ITEMS + j)
        feats, labels = self.syn[impression_ids]
        feats["impression_feats"]["real_feats"] = torch.mean(
            self.syn.impression_feats["real_feats"].view(NUM_USERS, NUM_ITEMS),
            dim=1).view(-1, 1)
        with torch.no_grad():
            self.simulator.eval()
            feats = self.syn.to_device(feats)
            outputs = torch.sigmoid(self.simulator(**feats))
            return torch.mean(
                (outputs > self.syn.cut).to(dtype=torch.float32)).item()

    def two_stage_eval(self, logits, ranker, k=30):
        sort_idx = torch.argsort(logits, dim=1, descending=True).cpu().numpy()
        topk_item_ids = []
        for i in range(NUM_USERS):
            topk_item_ids.append([])
            for j in sort_idx[i]:
                if j not in self.a[self.u == i]:
                    topk_item_ids[-1].append(j)
                if len(topk_item_ids[-1]) == k:
                    break
        time_feats = self.syn.to_device(
            torch.mean(
                self.syn.impression_feats["real_feats"].view(
                    NUM_USERS, NUM_ITEMS),
                dim=1).view(-1, 1))
        recommended = []
        for i in range(NUM_USERS):
            user_feats = {
                key: value[i].view(1, -1)
                for key, value in self.syn.user_feats.items()
            }
            item_feats = {
                key: value[topk_item_ids[i]]
                for key, value in self.syn.item_feats.items()
            }

            user_feats = self.syn.to_device(user_feats)
            item_feats = self.syn.to_device(item_feats)
            impression_feats = {"real_feats": time_feats[i].view(1, -1)}
            ranker_logits = ranker(user_feats, item_feats,
                                   impression_feats).view(1, -1)
            _, pred = torch.max(ranker_logits, 1)
            pred = pred.squeeze().item()
            recommended.append(topk_item_ids[i][pred])

        impression_ids = []
        for i in range(NUM_USERS):
            impression_ids.append(i * NUM_ITEMS + recommended[i])
        feats, labels = self.syn[impression_ids]
        feats["impression_feats"]["real_feats"] = torch.mean(
            self.syn.impression_feats["real_feats"].view(NUM_USERS, NUM_ITEMS),
            dim=1).view(-1, 1)
        with torch.no_grad():
            self.simulator.eval()
            feats = self.syn.to_device(feats)
            outputs = torch.sigmoid(self.simulator(**feats))
            return torch.mean(
                (outputs > self.syn.cut).to(dtype=torch.float32)).item()
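`two_stage_eval` serves each user in two steps. First, it sorts the candidate generator's logits, skips items already among the user's logged positives (`self.a[self.u == i]`), and keeps the top `k` survivors as candidates. Second, the ranker's highest-scoring candidate becomes the single recommendation that the simulator then scores. The plain-Python sketch below restates that selection rule for one user. The function name `two_stage_pick` and its list-based inputs are hypothetical stand-ins for the notebook's tensors.

```python
def two_stage_pick(nominator_scores, ranker_scores, seen, k):
    """Return (candidates, recommended item) for one user.

    nominator_scores / ranker_scores: per-item scores indexed by item id
    (hypothetical list inputs standing in for the notebook's tensors).
    seen: item ids to exclude (the user's logged positives).
    """
    # Stage 1: order all items by candidate-generator score, drop seen items,
    # and keep the first k survivors.
    order = sorted(range(len(nominator_scores)),
                   key=lambda j: nominator_scores[j], reverse=True)
    candidates = [j for j in order if j not in seen][:k]
    # Stage 2: the ranker chooses the single best candidate (argmax).
    best = max(candidates, key=lambda j: ranker_scores[j])
    return candidates, best


nom = [0.9, 0.1, 0.8, 0.7, 0.2]
rank = [0.0, 5.0, 0.3, 0.6, 9.0]
print(two_stage_pick(nom, rank, seen={0}, k=2))  # → ([2, 3], 3)
```

Item 4 has the highest ranker score, but it never reaches the ranker because stage 1 filters it out. This is the interaction between the two stages that the two-stage loss takes into account.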

Losses

def batch_select(mat, idx):
    mask = torch.arange(mat.size(1)).expand_as(mat).to(
        mat.device, dtype=torch.long)
    mask = (mask == idx.view(-1, 1))
    return torch.masked_select(mat, mask)
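`batch_select(mat, idx)` returns the single entry `mat[i, idx[i]]` for every row `i`. It builds a boolean mask that is true only in column `idx[i]` of row `i`, and `torch.masked_select` then flattens the selected entries in row order. In PyTorch, `torch.gather(mat, 1, idx.view(-1, 1)).squeeze(1)` gives the same result. The dependency-free sketch below states the intended semantics on nested lists. The name `batch_select_ref` is ours.

```python
def batch_select_ref(mat, idx):
    """Pick mat[i][idx[i]] for every row i (plain-list restatement)."""
    if len(mat) != len(idx):
        raise ValueError("need one column index per row")
    return [row[j] for row, j in zip(mat, idx)]


log_probs = [[-0.1, -2.3, -3.0],
             [-1.2, -0.4, -2.0]]
print(batch_select_ref(log_probs, [0, 2]))  # → [-0.1, -2.0]
```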


def unique_and_padding(mat, padding_idx, dim=-1):
    """Conducts unique operation along dim and pads to the same length."""
    samples, _ = torch.sort(mat, dim=dim)
    samples_roll = torch.roll(samples, -1, dims=dim)
    samples_diff = samples - samples_roll
    # Always keep the last element of each row; this also handles the edge
    # case where a row contains only one unique sample.
    samples_diff[:, -1] = 1
    samples_mask = torch.bitwise_not(samples_diff == 0)  # unique mask
    samples *= samples_mask.to(dtype=samples.dtype)
    samples += (1 - samples_mask.to(dtype=samples.dtype)) * padding_idx
    samples, _ = torch.sort(samples, dim=dim)
    # shrink size to max unique length
    samples = torch.unique(samples, dim=dim)
    return samples
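Inside `loss_2s`, `unique_and_padding` turns each row of sampled item ids into a sorted set of distinct ids, padded with `padding_idx` so that all rows share one width. Rows can contain repeats because sampling is done with replacement. The torch version works in four steps:

1. Sort each row.
2. Replace repeated values with `padding_idx`, which must exceed every real id so that padding sorts to the end.
3. Sort again.
4. Call `torch.unique` along the column dimension.

Step 4 collapses all-padding columns into one, so the output can keep one extra all-padding column that a strictly per-row version would not produce. That is harmless in `loss_2s`, where the padding id receives a ranker logit of `-inf` and a log-probability of `0`. A plain-Python sketch of the per-row behavior, using the hypothetical name `unique_and_padding_ref`:

```python
def unique_and_padding_ref(rows, padding_idx):
    """Deduplicate and sort each row, then right-pad to a common width."""
    uniq = [sorted(set(row)) for row in rows]
    width = max(len(u) for u in uniq)
    return [u + [padding_idx] * (width - len(u)) for u in uniq]


# Two sampled slates over items {1, 2, 3, 5}; padding id 9.
print(unique_and_padding_ref([[3, 1, 3, 5], [2, 2, 2, 1]], 9))
# → [[1, 3, 5], [1, 2, 9]]
```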


def loss_ce(logits, a, unused_p=None, unused_ranker_logits=None):
    """Cross entropy."""
    return -torch.mean(batch_select(F.log_softmax(logits, dim=1), a))


def loss_ips(logits,
             a,
             p,
             unused_ranker_logits=None,
             upper_limit=100,
             lower_limit=0.01):
    """IPS loss (one-stage)."""
    importance_weight = batch_select(F.softmax(logits.detach(), dim=1), a) / p
    importance_weight = torch.where(
        importance_weight > lower_limit, importance_weight,
        lower_limit * torch.ones_like(importance_weight))
    importance_weight = torch.where(
        importance_weight < upper_limit, importance_weight,
        upper_limit * torch.ones_like(importance_weight))
    importance_weight /= torch.mean(importance_weight)
    return -torch.mean(
        batch_select(F.log_softmax(logits, dim=1), a) * importance_weight)
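The importance weight in `loss_ips` is computed as follows:

- Start with the current policy's probability of the logged item divided by the logging probability `p`.
- Clip it to `[lower_limit, upper_limit]`.
- Divide by its batch mean (self-normalization).

All of this uses detached logits, so the weight acts as a constant in the gradient. The loss is then the weight-scaled negative log-likelihood of the logged item. The plain-Python sketch below covers only the weight computation and takes list inputs. The name `ips_weights` is hypothetical.

```python
def ips_weights(pi, p, lower_limit=0.01, upper_limit=100.0):
    """Clipped, self-normalized importance weights pi[i] / p[i].

    pi: current policy probability of each logged item (treated as a constant).
    p:  logging-policy probability of the same item.
    """
    raw = [a / b for a, b in zip(pi, p)]
    # Clip from below and above, as the two torch.where calls do.
    clipped = [min(max(w, lower_limit), upper_limit) for w in raw]
    # Divide by the batch mean so the weights average to 1.
    mean = sum(clipped) / len(clipped)
    return [w / mean for w in clipped]


# First raw weight: 0.9 / 0.03 = 30, clipped to 10 (upper_limit=10 as in the
# training loop); second: 0.5 / 0.5 = 1. The clipped mean is 5.5.
w = ips_weights([0.9, 0.5], [0.03, 0.5], upper_limit=10)
print([round(x, 4) for x in w])  # → [1.8182, 0.1818]
```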


def loss_2s(logits,
            a,
            p,
            ranker_logits,
            slate_sample_size=100,
            slate_size=30,
            temperature=np.e,
            alpha=1e-5,
            upper_limit=100,
            lower_limit=0.01):
    """Two stage loss."""
    num_logits = logits.size(1)
    rls = ranker_logits.detach()
    probs = F.softmax(logits.detach() / temperature, dim=1)
    log_probs = F.log_softmax(logits, dim=1)

    rls = torch.cat(
        [
            rls,
            torch.Tensor([float("-inf")]).to(rls.device).view(1, 1).expand(
                rls.size(0), 1)
        ],
        dim=1)
    log_probs = torch.cat(
        [log_probs,
         torch.zeros(log_probs.size(0), 1).to(log_probs.device)],
        dim=1)

    importance_weight = batch_select(F.softmax(logits.detach(), dim=1), a) / p
    importance_weight = torch.where(
        importance_weight > lower_limit, importance_weight,
        lower_limit * torch.ones_like(importance_weight))
    importance_weight = torch.where(
        importance_weight < upper_limit, importance_weight,
        upper_limit * torch.ones_like(importance_weight))
    importance_weight /= torch.mean(importance_weight)

    log_action_res = []
    sampled_slate_res = []
    for i in range(probs.size(0)):
        samples = torch.multinomial(
            probs[i], slate_sample_size * slate_size, replacement=True).view(
                slate_sample_size, slate_size)
        samples = torch.cat(
            [samples, a[i].view(1, 1).expand(samples.size(0), 1)], dim=1)
        samples = unique_and_padding(samples, num_logits)
        rp = F.softmax(
            F.embedding(samples, rls[i].view(-1, 1)).squeeze(-1), dim=1)
        lp = torch.sum(
            F.embedding(samples, log_probs[i].view(-1, 1)).squeeze(-1),
            dim=1) - log_probs[i, a[i]]
        sampled_slate_res.append(
            torch.mean(importance_weight[i] * rp[samples == a[i]] * lp))
        log_action_res.append(
            torch.mean(importance_weight[i] * rp[samples == a[i]]))
    loss = -torch.mean(
        torch.stack(log_action_res) * batch_select(log_probs, a))
    loss += -torch.mean(torch.stack(sampled_slate_res)) * alpha
    return loss
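Read term by term, `loss_2s` computes the objective below. The notation is ours and is transcribed from the code. The symbols are:

- $(u_i, a_i, p_i)$: one of $B$ logged positives in the batch.
- $\tilde w_i$: the clipped, self-normalized importance weight, computed exactly as in `loss_ips`.
- $\pi_\theta = \mathrm{softmax}(\text{logits})$: the candidate generator's policy.
- $s_{i1}, \dots, s_{iS}$: candidate sets for user $i$. Each is built from $K$ = `slate_size` draws sampled with replacement from $\mathrm{softmax}(\text{logits}/T)$, with $T$ = `temperature`; then $a_i$ is appended and duplicates are removed. There are $S$ = `slate_sample_size` such sets.
- $\rho_{ij}$: the detached ranker logit of item $j$ for user $u_i$.

Only $\pi_\theta$ carries gradient; $\tilde w_i$ and $r$ are constants.

```latex
\begin{aligned}
r(a_i \mid s) &= \frac{\exp(\rho_{i a_i})}{\sum_{j \in s} \exp(\rho_{ij})}, \qquad
\bar r_i = \frac{1}{S}\sum_{k=1}^{S} r(a_i \mid s_{ik}), \\
\mathcal{L}_{2s} &= -\frac{1}{B}\sum_{i=1}^{B} \tilde w_i \, \bar r_i \log \pi_\theta(a_i \mid u_i)
  - \frac{\alpha}{B}\sum_{i=1}^{B} \frac{\tilde w_i}{S}\sum_{k=1}^{S} r(a_i \mid s_{ik})
    \sum_{j \in s_{ik} \setminus \{a_i\}} \log \pi_\theta(j \mid u_i).
\end{aligned}
```

Compared with the IPS loss, the two terms differ as follows:

- The first term rescales each example by $\bar r_i$. This estimates how likely the ranker is to choose $a_i$ once $a_i$ is among candidates drawn from the temperature-smoothed candidate generator.
- The second term, scaled by $\alpha$, raises the log-probability of the items sampled alongside $a_i$, weighted the same way.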

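In `loss_2s`, the weight on the log-probability of each logged item is the IPS weight times an average over sampled slates. That average is the ranker's softmax probability of picking the logged item from each slate. The plain-Python sketch below mirrors that inner loop for a single user:

1. Draw `num_slates` slates of `slate_size` items, with replacement, from the temperature-scaled candidate distribution.
2. Add the logged item to each slate and deduplicate.
3. Average the ranker's softmax probability of the logged item within each set.

The function name `ranker_pick_prob` and its arguments are hypothetical. Here `random.choices` stands in for `torch.multinomial(..., replacement=True)`.

```python
import math
import random


def softmax(xs, temperature=1.0):
    """Numerically stable softmax of a list, with an optional temperature."""
    m = max(xs)
    exps = [math.exp((x - m) / temperature) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def ranker_pick_prob(nominator_logits, ranker_logits, a, num_slates=100,
                     slate_size=30, temperature=math.e, seed=0):
    """Monte Carlo average of the ranker's softmax probability of item a.

    Each slate: slate_size draws with replacement from the temperature-scaled
    candidate distribution, plus the logged item a, deduplicated.
    """
    rng = random.Random(seed)
    items = list(range(len(nominator_logits)))
    q = softmax(nominator_logits, temperature)
    total = 0.0
    for _ in range(num_slates):
        slate = set(rng.choices(items, weights=q, k=slate_size)) | {a}
        m = max(ranker_logits[j] for j in slate)  # for numerical stability
        denom = sum(math.exp(ranker_logits[j] - m) for j in slate)
        total += math.exp(ranker_logits[a] - m) / denom
    return total / num_slates


# If the candidate generator essentially only proposes the logged item, the
# ranker always picks it, even though it prefers items 1 and 2.
print(ranker_pick_prob([50.0, -50.0, -50.0], [0.0, 3.0, 3.0], a=0))  # → 1.0
```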
Main

!pip install -q torchnet
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from six.moves import cPickle as pickle
from torch.utils.data import DataLoader, Subset
from torchnet.meter import AUCMeter
parser = argparse.ArgumentParser()
parser.add_argument("--verbose", type=int, default=1, help="Verbose.")
parser.add_argument("--seed", type=int, default=0, help="Random seed.")
parser.add_argument("--loss_type", default="loss_ce")
parser.add_argument("--device", default="cuda:0")
parser.add_argument("--alpha", type=float, default=1e-3, help="Loss ratio.")
parser.add_argument("--lr", type=float, default=0.05, help="Learning rate.")
args = parser.parse_args(args=[])  # empty list: use the defaults inside a notebook
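Because `parse_args` receives an explicit empty argument list, the notebook ignores the kernel's own command line. It always runs with the defaults (`loss_ce`, `alpha=1e-3`, `lr=0.05`). To train with the two-stage loss instead, pass the flags as a list of strings. The sketch below rebuilds the same parser with only the options defined above and shows one such override. The flag values are only an example.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--verbose", type=int, default=1, help="Verbose.")
parser.add_argument("--seed", type=int, default=0, help="Random seed.")
parser.add_argument("--loss_type", default="loss_ce")
parser.add_argument("--device", default="cuda:0")
parser.add_argument("--alpha", type=float, default=1e-3, help="Loss ratio.")
parser.add_argument("--lr", type=float, default=0.05, help="Learning rate.")

# An explicit list bypasses sys.argv (which, inside Jupyter, holds kernel flags).
defaults = parser.parse_args([])
two_stage = parser.parse_args(["--loss_type", "loss_2s", "--alpha", "1e-4"])
print(defaults.loss_type, two_stage.loss_type, two_stage.alpha)
# → loss_ce loss_2s 0.0001
```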

torch.manual_seed(0)
torch.cuda.manual_seed(0)

filepath = "./ml-1m"
device = args.device

dataset = MovieLensDataset(filepath, device=device)

NUM_ITEMS = dataset.NUM_ITEMS
NUM_YEARS = dataset.NUM_YEARS
NUM_GENRES = dataset.NUM_GENRES

NUM_USERS = dataset.NUM_USERS
NUM_OCCUPS = dataset.NUM_OCCUPS
NUM_AGES = dataset.NUM_AGES
NUM_ZIPS = dataset.NUM_ZIPS
simulator_path = os.path.join(filepath, "simulator.pt")
if os.path.exists(simulator_path):
    simulator = ImpressionSimulator(use_impression_feats=True)
    simulator.load_state_dict(torch.load(simulator_path))
    simulator.to(device)
    simulator.eval()
else:
    # train a simulator model on the original ML-1M dataset
    # the simulator will be used to generate synthetic labels later

    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    num_samples = len(dataset)
    train_loader = DataLoader(
        Subset(dataset, list(range(num_samples * 3 // 5))),
        batch_size=128,
        shuffle=True)
    val_loader = DataLoader(
        Subset(dataset,
               list(range(num_samples * 3 // 5, num_samples * 4 // 5))),
        batch_size=128)
    test_loader = DataLoader(
        Subset(dataset, list(range(num_samples * 4 // 5, num_samples))),
        batch_size=128)

    simulator = ImpressionSimulator(use_impression_feats=True).to(device)
    opt = torch.optim.Adagrad(
        simulator.parameters(), lr=0.05, weight_decay=1e-4)
    criterion = nn.BCEWithLogitsLoss()

    simulator.train()
    for epoch in range(4):
        print("---epoch {}---".format(epoch))
        for step, batch in enumerate(train_loader):
            feats, labels = batch
            logits = simulator(**feats)
            loss = criterion(logits, labels)

            opt.zero_grad()
            loss.backward()
            opt.step()

            if (step + 1) % 500 == 0:
                with torch.no_grad():
                    simulator.eval()
                    auc = AUCMeter()
                    for feats, labels in val_loader:
                        outputs = torch.sigmoid(simulator(**feats))
                        auc.add(outputs, labels)
                    print(step, auc.value()[0])
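                simulator.train()
                # Stop the current epoch early once validation AUC is high
                # enough; training mode is restored first so later epochs train.
                if auc.value()[0] > 0.735:
                    break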

    simulator.to("cpu")
    torch.save(simulator.state_dict(), simulator_path)
---epoch 0---
499 0.6932387523620361
999 0.7135402368767171
1499 0.7215369956250887
1999 0.7251487885185414
2499 0.7250320801426405
2999 0.7272670074151073
3499 0.7281266442311184
3999 0.7294850286875231
4499 0.7296389614963857
---epoch 1---
499 0.7283547770191316
999 0.7295152100721607
1499 0.7300909216769196
1999 0.7302949709078976
2499 0.7307115513456678
2999 0.7303182560548731
3499 0.729832712732132
3999 0.7313203021611654
4499 0.7318606932964403
---epoch 2---
499 0.7313658487455917
999 0.731270933721071
1499 0.7307394064320014
1999 0.7323340285603053
2499 0.7322074757365146
2999 0.7321213073333425
3499 0.7330514084885155
3999 0.7328071902744475
4499 0.7332603719115534
---epoch 3---
499 0.7324623727995114
999 0.7323136933519381
1499 0.7327187807719571
1999 0.7329365712368654
2499 0.7318917432430467
2999 0.733582752772605
3499 0.7342934273983486
3999 0.7341365564445322
4499 0.7340031962421637
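The simulator cell above splits the dataset by position into training, validation and test subsets in a 3:1:1 ratio, using integer arithmetic on `num_samples`. The split is contiguous rather than shuffled; only the training loader shuffles, and only within its own subset. The subsets therefore follow whatever order `MovieLensDataset` stores the ratings in. A plain-Python sketch of the boundaries, using the hypothetical helper name `split_3_1_1`:

```python
def split_3_1_1(num_samples):
    """Contiguous train/val/test index ranges, as in the simulator cell."""
    train = range(0, num_samples * 3 // 5)
    val = range(num_samples * 3 // 5, num_samples * 4 // 5)
    test = range(num_samples * 4 // 5, num_samples)
    return train, val, test


train, val, test = split_3_1_1(1003)
print(len(train), len(val), len(test))  # → 601 201 201
```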
# Build the synthetic dataset: SyntheticMovieLensDataset uses the trained simulator to generate labels for every user-item pair.
synthetic_data_path = os.path.join(filepath, "full_impression_feats.pt")
syn = SyntheticMovieLensDataset(
    filepath, simulator_path, synthetic_data_path, device=device)

logging_policy_path = os.path.join(filepath, "logging_policy.pt")
if os.path.exists(logging_policy_path):
    logging_policy = Nominator()
    logging_policy.load_state_dict(torch.load(logging_policy_path))
    logging_policy.to(device)
    logging_policy.eval()
else:
    # train a logging policy using the synthetic dataset
    num_samples = len(syn)
    idx_list = list(range(num_samples))
    rs = np.random.RandomState(0)
    rs.shuffle(idx_list)
    train_idx = idx_list[:10000]
    val_idx = idx_list[10000:20000]
    test_idx = idx_list[-100000:]
    train_loader = DataLoader(
        Subset(syn, train_idx), batch_size=128, shuffle=True)
    val_loader = DataLoader(Subset(syn, val_idx), batch_size=128)
    test_loader = DataLoader(Subset(syn, test_idx), batch_size=128)

    logging_policy = Nominator().to(device)
    opt = torch.optim.Adagrad(
        logging_policy.parameters(), lr=0.05, weight_decay=1e-4)
    criterion = nn.BCEWithLogitsLoss()

    logging_policy.train()
    for epoch in range(40):
        print("---epoch {}---".format(epoch))
        for step, batch in enumerate(train_loader):
            feats, labels = batch
            feats = syn.to_device(feats)
            labels = syn.to_device(labels)
            logits = logging_policy(feats["user_feats"], feats["item_feats"])
            loss = criterion(logits, labels)

            opt.zero_grad()
            loss.backward()
            opt.step()

        with torch.no_grad():
            logging_policy.eval()
            auc = AUCMeter()
            for feats, labels in val_loader:
                feats = syn.to_device(feats)
                labels = syn.to_device(labels)
                outputs = torch.sigmoid(
                    logging_policy(feats["user_feats"], feats["item_feats"]))
                auc.add(outputs, labels)
            print(step, auc.value()[0])

            logging_policy.train()

    logging_policy.eval()

    logging_policy.to("cpu")
    torch.save(logging_policy.state_dict(), logging_policy_path)
    logging_policy.to(device)
generating impression_feats
100%|██████████| 46906/46906 [01:04<00:00, 726.33it/s]
saved impression_feats
---epoch 0---
78 0.7303382306338263
---epoch 1---
78 0.7581630234600971
---epoch 2---
78 0.7758875382286556
---epoch 3---
78 0.7947954816858157
---epoch 4---
78 0.8092959307857327
---epoch 5---
78 0.8209026242576345
---epoch 6---
78 0.8281561300331621
---epoch 7---
78 0.831641647559472
---epoch 8---
78 0.8359850391331326
---epoch 9---
78 0.838561799838773
---epoch 10---
78 0.8401704563395371
---epoch 11---
78 0.8411352317648504
---epoch 12---
78 0.8419039896279034
---epoch 13---
78 0.8420276300116679
---epoch 14---
78 0.8417600912427988
---epoch 15---
78 0.8424424910531916
---epoch 16---
78 0.8427186846489241
---epoch 17---
78 0.8426326119202265
---epoch 18---
78 0.8428287245904745
---epoch 19---
78 0.842670274683281
---epoch 20---
78 0.8428754226123425
---epoch 21---
78 0.8429700550599161
---epoch 22---
78 0.8431265076993719
---epoch 23---
78 0.8433620901844372
---epoch 24---
78 0.8435445073044836
---epoch 25---
78 0.8435860694950261
---epoch 26---
78 0.8435829309314381
---epoch 27---
78 0.843662155885035
---epoch 28---
78 0.8437115169305534
---epoch 29---
78 0.8438241247877666
---epoch 30---
78 0.843904205713251
---epoch 31---
78 0.843980862751185
---epoch 32---
78 0.8440994624116114
---epoch 33---
78 0.8442256707110387
---epoch 34---
78 0.8444167426579487
---epoch 35---
78 0.8444864568127943
---epoch 36---
78 0.8446461431238256
---epoch 37---
78 0.8447101507994208
---epoch 38---
78 0.8447210882179845
---epoch 39---
78 0.8447308843406981
def generate_bandit_samples(logging_policy, syn, k=5):
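    """Generates partially labeled bandit samples with the logging policy.

    Arguments:
        logging_policy: The policy used to sample items for each user.
        syn: The synthetic dataset providing features and labels.
        k: The number of items sampled (without replacement) for each user.

    Returns:
        Flattened arrays of users, items, logging probabilities and rewards.
    """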
    logging_policy.set_binary(False)
    with torch.no_grad():
        feats = {}
        feats["user_feats"] = syn.user_feats
        feats["item_feats"] = syn.item_feats
        feats = syn.to_device(feats)
        probs = F.softmax(logging_policy(**feats), dim=1)

    sampled_users = []
    sampled_actions = []
    sampled_probs = []
    sampled_rewards = []
    for i in range(probs.size(0)):
        sampled_users.append([i] * k)
        sampled_actions.append(
            torch.multinomial(probs[i], k).cpu().numpy().tolist())
        sampled_probs.append(
            probs[i, sampled_actions[-1]].cpu().numpy().tolist())
        sampled_rewards.append(syn.impression_feats["labels"][[
            i * probs.size(1) + j for j in sampled_actions[-1]
        ]].numpy().tolist())
    return np.array(sampled_users).reshape(-1), np.array(
        sampled_actions).reshape(-1), np.array(sampled_probs).reshape(
            -1), np.array(sampled_rewards).reshape(-1)
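`generate_bandit_samples` draws `k` distinct items per user with `torch.multinomial(probs[i], k)`, which samples without replacement by default. For each drawn item, it records the logging policy's softmax probability as the propensity `p`. This recorded value is the single-draw probability, not the probability that the item ends up among the `k` draws. The plain-Python sketch below expresses the same procedure as sequential draws, each renormalized over the items not yet chosen. The name `sample_slate` is hypothetical.

```python
import random


def sample_slate(probs, k, seed=0):
    """Draw k distinct items by sequential renormalized sampling.

    Returns (items, propensities), where each propensity is probs[item]:
    the single-draw probability, mirroring what generate_bandit_samples logs.
    """
    if k > sum(1 for x in probs if x > 0):
        raise ValueError("not enough items with nonzero probability")
    rng = random.Random(seed)
    remaining = list(range(len(probs)))
    items = []
    for _ in range(k):
        weights = [probs[j] for j in remaining]
        choice = rng.choices(remaining, weights=weights, k=1)[0]
        remaining.remove(choice)  # no replacement: exclude from later draws
        items.append(choice)
    return items, [probs[j] for j in items]


items, props = sample_slate([0.5, 0.2, 0.2, 0.1], k=3)
print(items, props)
```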


torch.manual_seed(0)
torch.cuda.manual_seed(0)

u, a, p, r = generate_bandit_samples(
    logging_policy, syn,
    k=5)  # u: user, a: item, p: logging policy probability, r: reward/label

simulator = simulator.to(device)
ev = Evaluator(u[r > 0], a[r > 0], simulator, syn)

all_user_feats = syn.to_device(syn.user_feats)
all_item_feats = syn.to_device(syn.item_feats)
all_impression_feats = syn.to_device({
    "real_feats":
    torch.mean(
        syn.impression_feats["real_feats"].view(NUM_USERS, NUM_ITEMS),
        dim=1).view(-1, 1)
})
# Split validation/test users.
num_val_users = 2000
val_user_list = list(range(0, num_val_users))
test_user_list = list(range(num_val_users, NUM_USERS))

test_item_feats = all_item_feats
test_user_feats = syn.to_device(
    {key: value[test_user_list]
     for key, value in all_user_feats.items()})
test_impression_feats = syn.to_device({
    key: value[test_user_list]
    for key, value in all_impression_feats.items()
})

val_item_feats = all_item_feats
val_user_feats = syn.to_device(
    {key: value[val_user_list]
     for key, value in all_user_feats.items()})
val_impression_feats = syn.to_device(
    {key: value[val_user_list]
     for key, value in all_impression_feats.items()})

ranker_path = os.path.join(filepath, "ranker.pt")

if os.path.exists(ranker_path):
    ranker = Ranker()
    ranker.load_state_dict(torch.load(ranker_path))
    ranker.to(device)
    ranker.eval()
    ranker.set_binary(False)
else:
    # train the ranker with binary cross-entropy
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    batch_size = 128
    neg_sample_size = 29

    ranker = Ranker().to(device)
    opt = torch.optim.Adagrad(ranker.parameters(), lr=0.05, weight_decay=1e-4)
    criterion = nn.BCEWithLogitsLoss()

    rs = np.random.RandomState(0)
    ranker.train()
    for epoch in range(10):
        print("---epoch {}---".format(epoch))
        for step in range(len(u) // batch_size):
            user_list = u[step * batch_size:(step + 1) * batch_size]
            item_list = a[step * batch_size:(step + 1) * batch_size]
            user_feats = syn.to_device({
                key: value[user_list]
                for key, value in syn.user_feats.items()
            })
            item_feats = syn.to_device({
                key: value[item_list]
                for key, value in syn.item_feats.items()
            })
            impression_list = [
                user_id * NUM_ITEMS + item_id
                for user_id, item_id in zip(user_list, item_list)
            ]
            impression_feats = syn.to_device({
                "real_feats":
                syn.impression_feats["real_feats"][impression_list]
            })

            labels = torch.FloatTensor(
                r[step * batch_size:(step + 1) * batch_size]).to(device)

            logits = ranker(user_feats, item_feats, impression_feats)
            loss = criterion(logits, labels)

            opt.zero_grad()
            loss.backward()
            opt.step()

        with torch.no_grad():
            ranker.eval()
            ranker.set_binary(False)
            logits = ranker(all_user_feats, all_item_feats,
                            all_impression_feats)

            print(step, ev.one_stage_eval(logits))

            # Evaluate ranking metrics on validation users.
            logits = ranker(val_user_feats, val_item_feats,
                            val_impression_feats)
            print(step, ev.one_stage_ranking_eval(logits, val_user_list))

            ranker.train()
            ranker.set_binary(True)

    ranker.eval()
    ranker.set_binary(False)

    ranker.to("cpu")
    torch.save(ranker.state_dict(), ranker_path)
    ranker.to(device)
---epoch 0---
234 0.49751657247543335
234 OrderedDict([('Precision@1', 0.525853157043457), ('Precision@5', 0.37435382604599), ('Precision@10', 0.3275078535079956), ('Recall@1', 0.001950422883965075), ('Recall@5', 0.006817308254539967), ('Recall@10', 0.011337128467857838), ('NDCG@5', 0.40602830052375793), ('NDCG@10', 0.36229410767555237), ('NDCG@20', 0.3150430917739868)])
---epoch 1---
234 0.5798013210296631
234 OrderedDict([('Precision@1', 0.6003102660179138), ('Precision@5', 0.4798329174518585), ('Precision@10', 0.42228519916534424), ('Recall@1', 0.002812947379425168), ('Recall@5', 0.011755581945180893), ('Recall@10', 0.02120870165526867), ('NDCG@5', 0.5004847645759583), ('NDCG@10', 0.45322883129119873), ('NDCG@20', 0.39284855127334595)])
---epoch 2---
234 0.5889073014259338
234 OrderedDict([('Precision@1', 0.6106514930725098), ('Precision@5', 0.5765244960784912), ('Precision@10', 0.5148389339447021), ('Recall@1', 0.003235821146517992), ('Recall@5', 0.016686489805579185), ('Recall@10', 0.03442412614822388), ('NDCG@5', 0.5764947533607483), ('NDCG@10', 0.5345680713653564), ('NDCG@20', 0.4556241035461426)])
---epoch 3---
234 0.628311276435852
234 OrderedDict([('Precision@1', 0.6411582231521606), ('Precision@5', 0.6560497283935547), ('Precision@10', 0.6184099912643433), ('Recall@1', 0.0033987725619226694), ('Recall@5', 0.0258512943983078), ('Recall@10', 0.045876603573560715), ('NDCG@5', 0.6525105237960815), ('NDCG@10', 0.6282448172569275), ('NDCG@20', 0.5369541049003601)])
---epoch 4---
234 0.7625827789306641
234 OrderedDict([('Precision@1', 0.7487073540687561), ('Precision@5', 0.750673770904541), ('Precision@10', 0.6952475309371948), ('Recall@1', 0.004413578659296036), ('Recall@5', 0.033956073224544525), ('Recall@10', 0.05853657424449921), ('NDCG@5', 0.7535479068756104), ('NDCG@10', 0.715246319770813), ('NDCG@20', 0.6256369352340698)])
---epoch 5---
234 0.8622516989707947
234 OrderedDict([('Precision@1', 0.8490175604820251), ('Precision@5', 0.8347484469413757), ('Precision@10', 0.7592080235481262), ('Recall@1', 0.009471286088228226), ('Recall@5', 0.045272961258888245), ('Recall@10', 0.07097340375185013), ('NDCG@5', 0.840373158454895), ('NDCG@10', 0.7897650003433228), ('NDCG@20', 0.717039942741394)])
---epoch 6---
234 0.9102649092674255
234 OrderedDict([('Precision@1', 0.9089968800544739), ('Precision@5', 0.8815937042236328), ('Precision@10', 0.8184615969657898), ('Recall@1', 0.015791630372405052), ('Recall@5', 0.053664132952690125), ('Recall@10', 0.07895129173994064), ('NDCG@5', 0.8917895555496216), ('NDCG@10', 0.8513309359550476), ('NDCG@20', 0.7841229438781738)])
---epoch 7---
234 0.9369205236434937
234 OrderedDict([('Precision@1', 0.9441571831703186), ('Precision@5', 0.9020691514015198), ('Precision@10', 0.8512938618659973), ('Recall@1', 0.01766125112771988), ('Recall@5', 0.05773942917585373), ('Recall@10', 0.08265078067779541), ('NDCG@5', 0.9147428274154663), ('NDCG@10', 0.8829164505004883), ('NDCG@20', 0.8230650424957275)])
---epoch 8---
234 0.9519867897033691
234 OrderedDict([('Precision@1', 0.9565666913986206), ('Precision@5', 0.9139613509178162), ('Precision@10', 0.8667538166046143), ('Recall@1', 0.01800968125462532), ('Recall@5', 0.060318805277347565), ('Recall@10', 0.08457545191049576), ('NDCG@5', 0.9267275929450989), ('NDCG@10', 0.8979306817054749), ('NDCG@20', 0.844747006893158)])
---epoch 9---
234 0.9541391134262085
234 OrderedDict([('Precision@1', 0.9601861238479614), ('Precision@5', 0.9185118079185486), ('Precision@10', 0.8744063973426819), ('Recall@1', 0.018171625211834908), ('Recall@5', 0.061439476907253265), ('Recall@10', 0.08633137494325638), ('NDCG@5', 0.9311888217926025), ('NDCG@10', 0.9044811129570007), ('NDCG@20', 0.8586583733558655)])
u = u[r > 0]
a = a[r > 0]
p = p[r > 0]

batch_size = 128
check_metric = "Precision@10"

torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

nominator = Nominator().to(device)
nominator.set_binary(False)

opt = torch.optim.Adagrad(
    nominator.parameters(), lr=args.lr, weight_decay=1e-4)

rs = np.random.RandomState(0)
nominator.train()
Nominator(
  (item_rep): ItemRep(
    (item_embedding): Embedding(3884, 10, padding_idx=0)
    (year_embedding): Embedding(81, 5)
    (genre_linear): Linear(in_features=18, out_features=5, bias=True)
  )
  (user_rep): UserRep(
    (user_embedding): Embedding(6041, 10, padding_idx=0)
    (gender_embedding): Embedding(2, 5)
    (age_embedding): Embedding(7, 5)
    (occup_embedding): Embedding(21, 5)
    (zip_embedding): Embedding(3439, 5)
  )
  (linear): Linear(in_features=30, out_features=20, bias=True)
)
# Init results.
best_epoch = 0
best_result = 0.0
val_results, test_results = [], []

if args.loss_type == "loss_2s":
    with torch.no_grad():
        ranker.eval()
        ranker.set_binary(False)
        ranker_logits = ranker(all_user_feats, all_item_feats,
                               all_impression_feats)

for epoch in range(20):
    print("---epoch {}---".format(epoch))
    for step in range(len(u) // batch_size):
        item_ids = torch.LongTensor(
            a[step * batch_size:(step + 1) * batch_size]).to(device)
        item_probs = torch.FloatTensor(
            p[step * batch_size:(step + 1) * batch_size]).to(device)

        user_ids = u[step * batch_size:(step + 1) * batch_size]
        user_feats = {
            key: value[user_ids]
            for key, value in syn.user_feats.items()
        }
        user_feats = syn.to_device(user_feats)

        logits = nominator(user_feats, all_item_feats)  # score all items

        if args.loss_type == "loss_ce":
            loss = loss_ce(logits, item_ids, item_probs)
        elif args.loss_type == "loss_ips":
            loss = loss_ips(logits, item_ids, item_probs, upper_limit=10)
        elif args.loss_type == "loss_2s":
            batch_ranker_logits = F.embedding(
                torch.LongTensor(user_ids).to(device), ranker_logits)
            loss = loss_2s(
                logits,
                item_ids,
                item_probs,
                batch_ranker_logits,
                upper_limit=10,
                alpha=args.alpha)
        else:
            raise NotImplementedError(
                "{} not supported.".format(args.loss_type))

        opt.zero_grad()
        loss.backward()
        opt.step()

    with torch.no_grad():
        nominator.eval()

        logits = nominator(all_user_feats, all_item_feats)
        print("1 stage", ev.one_stage_eval(logits))
        print("2 stage", ev.two_stage_eval(logits, ranker))

        # Evaluate ranking metrics on validation users.
        logits = nominator(val_user_feats, val_item_feats)
        one_stage_results = ev.one_stage_ranking_eval(logits, val_user_list)
        print("1 stage (val)", one_stage_results)
        two_stage_results = ev.two_stage_ranking_eval(logits, ranker,
                                                      val_user_list)
        print("2 stage (val)", two_stage_results)
        val_results.append((one_stage_results, two_stage_results))
        # Log best epoch
        if two_stage_results[check_metric] > best_result:
            best_epoch = epoch
            best_result = two_stage_results[check_metric]
        # Evaluate ranking metrics on test users.
        logits = nominator(test_user_feats, test_item_feats)
        one_stage_results = ev.one_stage_ranking_eval(logits, test_user_list)
        print("1 stage (test)", one_stage_results)
        two_stage_results = ev.two_stage_ranking_eval(logits, ranker,
                                                      test_user_list)
        print("2 stage (test)", two_stage_results)
        test_results.append((one_stage_results, two_stage_results))

        nominator.train()
---epoch 0---
1 stage 0.8389073014259338
2 stage 0.931456983089447
1 stage (val) OrderedDict([('Precision@1', 0.7761116623878479), ('Precision@5', 0.7825230956077576), ('Precision@10', 0.7693895101547241), ('Recall@1', 0.0026108406018465757), ('Recall@5', 0.01443537138402462), ('Recall@10', 0.02800210937857628), ('NDCG@5', 0.7805129289627075), ('NDCG@10', 0.7720701694488525), ('NDCG@20', 0.7773204445838928)])
2 stage (val) OrderedDict([('Precision@1', 0.9152016639709473), ('Precision@5', 0.8799382448196411), ('Precision@10', 0.8639603853225708), ('Recall@1', 0.010427681729197502), ('Recall@5', 0.02969161979854107), ('Recall@10', 0.04909321665763855), ('NDCG@5', 0.8888923525810242), ('NDCG@10', 0.8768721222877502), ('NDCG@20', 0.8476952910423279)])
1 stage (test) OrderedDict([('Precision@1', 0.894406795501709), ('Precision@5', 0.8917496204376221), ('Precision@10', 0.8819892406463623), ('Recall@1', 0.003005631733685732), ('Recall@5', 0.015431685373187065), ('Recall@10', 0.030124178156256676), ('NDCG@5', 0.8914320468902588), ('NDCG@10', 0.8846309781074524), ('NDCG@20', 0.8876312971115112)])
2 stage (test) OrderedDict([('Precision@1', 0.9671432375907898), ('Precision@5', 0.9465765357017517), ('Precision@10', 0.9378731846809387), ('Recall@1', 0.006223659496754408), ('Recall@5', 0.022564038634300232), ('Recall@10', 0.0406784787774086), ('NDCG@5', 0.9511790871620178), ('NDCG@10', 0.9443507790565491), ('NDCG@20', 0.9283153414726257)])
---epoch 1---
1 stage 0.8367549777030945
2 stage 0.9365894198417664
1 stage (val) OrderedDict([('Precision@1', 0.7730093002319336), ('Precision@5', 0.780661404132843), ('Precision@10', 0.7693895101547241), ('Recall@1', 0.002592692384496331), ('Recall@5', 0.01431603729724884), ('Recall@10', 0.027908695861697197), ('NDCG@5', 0.7788269519805908), ('NDCG@10', 0.7716042399406433), ('NDCG@20', 0.7761374115943909)])
2 stage (val) OrderedDict([('Precision@1', 0.9255428910255432), ('Precision@5', 0.883350670337677), ('Precision@10', 0.8641155958175659), ('Recall@1', 0.012946750037372112), ('Recall@5', 0.03286115825176239), ('Recall@10', 0.052004944533109665), ('NDCG@5', 0.894393801689148), ('NDCG@10', 0.8804171681404114), ('NDCG@20', 0.851220965385437)])
1 stage (test) OrderedDict([('Precision@1', 0.8926511406898499), ('Precision@5', 0.889492392539978), ('Precision@10', 0.8817635774612427), ('Recall@1', 0.002984470222145319), ('Recall@5', 0.01524937804788351), ('Recall@10', 0.030079031363129616), ('NDCG@5', 0.8897153735160828), ('NDCG@10', 0.8841732740402222), ('NDCG@20', 0.8868218660354614)])
2 stage (test) OrderedDict([('Precision@1', 0.9699021577835083), ('Precision@5', 0.9467771649360657), ('Precision@10', 0.9376724362373352), ('Recall@1', 0.007032747846096754), ('Recall@5', 0.023379521444439888), ('Recall@10', 0.04146404191851616), ('NDCG@5', 0.9522501826286316), ('NDCG@10', 0.9452073574066162), ('NDCG@20', 0.9290862083435059)])
---epoch 2---
1 stage 0.8357616066932678
2 stage 0.9385761618614197
1 stage (val) OrderedDict([('Precision@1', 0.771458089351654), ('Precision@5', 0.7803509831428528), ('Precision@10', 0.7697516083717346), ('Recall@1', 0.002546904841437936), ('Recall@5', 0.014310559257864952), ('Recall@10', 0.027927899733185768), ('NDCG@5', 0.7782995104789734), ('NDCG@10', 0.7716609239578247), ('NDCG@20', 0.7768319845199585)])
2 stage (val) OrderedDict([('Precision@1', 0.9281282424926758), ('Precision@5', 0.8846949934959412), ('Precision@10', 0.8650981783866882), ('Recall@1', 0.013078119605779648), ('Recall@5', 0.03309616446495056), ('Recall@10', 0.05231418088078499), ('NDCG@5', 0.895997166633606), ('NDCG@10', 0.8816695213317871), ('NDCG@20', 0.8517394661903381)])
1 stage (test) OrderedDict([('Precision@1', 0.8918986916542053), ('Precision@5', 0.8886397480964661), ('Precision@10', 0.8822652101516724), ('Recall@1', 0.002973585156723857), ('Recall@5', 0.015205418691039085), ('Recall@10', 0.030113695189356804), ('NDCG@5', 0.8889814615249634), ('NDCG@10', 0.8843721747398376), ('NDCG@20', 0.886855959892273)])
2 stage (test) OrderedDict([('Precision@1', 0.9716578722000122), ('Precision@5', 0.9481315016746521), ('Precision@10', 0.9381741285324097), ('Recall@1', 0.007616504095494747), ('Recall@5', 0.024297403171658516), ('Recall@10', 0.0423288531601429), ('NDCG@5', 0.9539521932601929), ('NDCG@10', 0.9464597702026367), ('NDCG@20', 0.9301363229751587)])
---epoch 3---
1 stage 0.8339403867721558
2 stage 0.940562903881073
1 stage (val) OrderedDict([('Precision@1', 0.7683557271957397), ('Precision@5', 0.7795237898826599), ('Precision@10', 0.7704238891601562), ('Recall@1', 0.0025028539821505547), ('Recall@5', 0.014192001894116402), ('Recall@10', 0.028007399290800095), ('NDCG@5', 0.7769649624824524), ('NDCG@10', 0.7715820074081421), ('NDCG@20', 0.7766546607017517)])
2 stage (val) OrderedDict([('Precision@1', 0.9332988858222961), ('Precision@5', 0.8861426711082458), ('Precision@10', 0.8660286664962769), ('Recall@1', 0.01370152272284031), ('Recall@5', 0.034078482538461685), ('Recall@10', 0.05330389738082886), ('NDCG@5', 0.8979603052139282), ('NDCG@10', 0.88340163230896), ('NDCG@20', 0.8530087471008301)])
1 stage (test) OrderedDict([('Precision@1', 0.8906446099281311), ('Precision@5', 0.8880879282951355), ('Precision@10', 0.8825662732124329), ('Recall@1', 0.002958007389679551), ('Recall@5', 0.015176204033195972), ('Recall@10', 0.030142122879624367), ('NDCG@5', 0.8884179592132568), ('NDCG@10', 0.884431004524231), ('NDCG@20', 0.8868438005447388)])
2 stage (test) OrderedDict([('Precision@1', 0.9721595048904419), ('Precision@5', 0.9488838315010071), ('Precision@10', 0.938500165939331), ('Recall@1', 0.00781310349702835), ('Recall@5', 0.02463247813284397), ('Recall@10', 0.04265395551919937), ('NDCG@5', 0.9547868967056274), ('NDCG@10', 0.9471161961555481), ('NDCG@20', 0.9306148290634155)])
---epoch 4---
1 stage 0.8331125974655151
2 stage 0.9425497055053711
1 stage (val) OrderedDict([('Precision@1', 0.7652533650398254), ('Precision@5', 0.7786965370178223), ('Precision@10', 0.7709408402442932), ('Recall@1', 0.0024512740783393383), ('Recall@5', 0.014100495725870132), ('Recall@10', 0.028065374121069908), ('NDCG@5', 0.7758020758628845), ('NDCG@10', 0.7715065479278564), ('NDCG@20', 0.7763851284980774)])
2 stage (val) OrderedDict([('Precision@1', 0.9358841776847839), ('Precision@5', 0.8872801661491394), ('Precision@10', 0.8661839962005615), ('Recall@1', 0.013996284455060959), ('Recall@5', 0.03461095318198204), ('Recall@10', 0.05393322929739952), ('NDCG@5', 0.8994066715240479), ('NDCG@10', 0.8842594623565674), ('NDCG@20', 0.8535211086273193)])
1 stage (test) OrderedDict([('Precision@1', 0.890895426273346), ('Precision@5', 0.8875863552093506), ('Precision@10', 0.8828924298286438), ('Recall@1', 0.002969894791021943), ('Recall@5', 0.015152904205024242), ('Recall@10', 0.03018353506922722), ('NDCG@5', 0.887993335723877), ('NDCG@10', 0.8845545053482056), ('NDCG@20', 0.886780858039856)])
2 stage (test) OrderedDict([('Precision@1', 0.9739152193069458), ('Precision@5', 0.9496864676475525), ('Precision@10', 0.9387760162353516), ('Recall@1', 0.007960631512105465), ('Recall@5', 0.024898577481508255), ('Recall@10', 0.04285132512450218), ('NDCG@5', 0.9557564854621887), ('NDCG@10', 0.9477135539054871), ('NDCG@20', 0.9310080409049988)])
---epoch 5---
1 stage 0.8336092829704285
2 stage 0.942218542098999
1 stage (val) OrderedDict([('Precision@1', 0.765770435333252), ('Precision@5', 0.7783864140510559), ('Precision@10', 0.7712510228157043), ('Recall@1', 0.0024701899383217096), ('Recall@5', 0.014151088893413544), ('Recall@10', 0.02811473049223423), ('NDCG@5', 0.7756521701812744), ('NDCG@10', 0.7717127799987793), ('NDCG@20', 0.7761114239692688)])
2 stage (val) OrderedDict([('Precision@1', 0.9364012479782104), ('Precision@5', 0.888003945350647), ('Precision@10', 0.8662872314453125), ('Recall@1', 0.013898404315114021), ('Recall@5', 0.03511940315365791), ('Recall@10', 0.05434870347380638), ('NDCG@5', 0.8998069763183594), ('NDCG@10', 0.8843777179718018), ('NDCG@20', 0.8532311916351318)])
1 stage (test) OrderedDict([('Precision@1', 0.8913970589637756), ('Precision@5', 0.887134850025177), ('Precision@10', 0.8827671408653259), ('Recall@1', 0.0029800296761095524), ('Recall@5', 0.015131834894418716), ('Recall@10', 0.03018496185541153), ('NDCG@5', 0.8876475095748901), ('NDCG@10', 0.8844181895256042), ('NDCG@20', 0.8867219686508179)])
2 stage (test) OrderedDict([('Precision@1', 0.9731627702713013), ('Precision@5', 0.9494857788085938), ('Precision@10', 0.9387760162353516), ('Recall@1', 0.007924147881567478), ('Recall@5', 0.024823080748319626), ('Recall@10', 0.042789481580257416), ('NDCG@5', 0.9554441571235657), ('NDCG@10', 0.9475792646408081), ('NDCG@20', 0.9307677149772644)])
---epoch 6---
1 stage 0.8346026539802551
2 stage 0.9417218565940857
1 stage (val) OrderedDict([('Precision@1', 0.7652533650398254), ('Precision@5', 0.7777659296989441), ('Precision@10', 0.7713546752929688), ('Recall@1', 0.002462433883920312), ('Recall@5', 0.014112057164311409), ('Recall@10', 0.02813573367893696), ('NDCG@5', 0.7751555442810059), ('NDCG@10', 0.7717844247817993), ('NDCG@20', 0.7759582996368408)])
2 stage (val) OrderedDict([('Precision@1', 0.9353671073913574), ('Precision@5', 0.8887278437614441), ('Precision@10', 0.866390585899353), ('Recall@1', 0.013995282351970673), ('Recall@5', 0.035590916872024536), ('Recall@10', 0.05474210903048515), ('NDCG@5', 0.9002852439880371), ('NDCG@10', 0.8846437931060791), ('NDCG@20', 0.8533338904380798)])
1 stage (test) OrderedDict([('Precision@1', 0.8931527733802795), ('Precision@5', 0.8867837190628052), ('Precision@10', 0.8830179572105408), ('Recall@1', 0.0030013115610927343), ('Recall@5', 0.01514824852347374), ('Recall@10', 0.03022931143641472), ('NDCG@5', 0.8876431584358215), ('NDCG@10', 0.8846896886825562), ('NDCG@20', 0.886621356010437)])
2 stage (test) OrderedDict([('Precision@1', 0.9729119539260864), ('Precision@5', 0.9496863484382629), ('Precision@10', 0.9388262033462524), ('Recall@1', 0.007614111993461847), ('Recall@5', 0.024715635925531387), ('Recall@10', 0.04267028719186783), ('NDCG@5', 0.9553412199020386), ('NDCG@10', 0.9473593235015869), ('NDCG@20', 0.9304744601249695)])
---epoch 7---
1 stage 0.8350993394851685
2 stage 0.9408940672874451
1 stage (val) OrderedDict([('Precision@1', 0.7652533650398254), ('Precision@5', 0.7769388556480408), ('Precision@10', 0.771303117275238), ('Recall@1', 0.002458712086081505), ('Recall@5', 0.014063295908272266), ('Recall@10', 0.028185289353132248), ('NDCG@5', 0.7745227813720703), ('NDCG@10', 0.7716432809829712), ('NDCG@20', 0.7753005027770996)])
2 stage (val) OrderedDict([('Precision@1', 0.9322647452354431), ('Precision@5', 0.8887277841567993), ('Precision@10', 0.865925133228302), ('Recall@1', 0.013905792497098446), ('Recall@5', 0.035767413675785065), ('Recall@10', 0.05483044683933258), ('NDCG@5', 0.8999837636947632), ('NDCG@10', 0.8842496871948242), ('NDCG@20', 0.8529420495033264)])
1 stage (test) OrderedDict([('Precision@1', 0.8939051628112793), ('Precision@5', 0.8863823413848877), ('Precision@10', 0.8830680847167969), ('Recall@1', 0.003021263051778078), ('Recall@5', 0.015123428776860237), ('Recall@10', 0.03023369051516056), ('NDCG@5', 0.887448787689209), ('NDCG@10', 0.8847446441650391), ('NDCG@20', 0.886391818523407)])
2 stage (test) OrderedDict([('Precision@1', 0.9731627702713013), ('Precision@5', 0.9496863484382629), ('Precision@10', 0.9387257099151611), ('Recall@1', 0.00767667219042778), ('Recall@5', 0.024805083870887756), ('Recall@10', 0.04274982586503029), ('NDCG@5', 0.9554198384284973), ('NDCG@10', 0.9473903179168701), ('NDCG@20', 0.9302816390991211)])
---epoch 8---
1 stage 0.8354305028915405
2 stage 0.9399006962776184
1 stage (val) OrderedDict([('Precision@1', 0.765770435333252), ('Precision@5', 0.7759048342704773), ('Precision@10', 0.771199643611908), ('Recall@1', 0.0024883777368813753), ('Recall@5', 0.013925631530582905), ('Recall@10', 0.028112078085541725), ('NDCG@5', 0.7737619876861572), ('NDCG@10', 0.7714693546295166), ('NDCG@20', 0.7745940685272217)])
2 stage (val) OrderedDict([('Precision@1', 0.9307135343551636), ('Precision@5', 0.8879006505012512), ('Precision@10', 0.8653048872947693), ('Recall@1', 0.013744794763624668), ('Recall@5', 0.03534204140305519), ('Recall@10', 0.05429835617542267), ('NDCG@5', 0.8989336490631104), ('NDCG@10', 0.8833516836166382), ('NDCG@20', 0.8518505096435547)])
1 stage (test) OrderedDict([('Precision@1', 0.8941559791564941), ('Precision@5', 0.8859308362007141), ('Precision@10', 0.8831183314323425), ('Recall@1', 0.0030219820328056812), ('Recall@5', 0.015090204775333405), ('Recall@10', 0.030239716172218323), ('NDCG@5', 0.8871259689331055), ('NDCG@10', 0.8847541809082031), ('NDCG@20', 0.8861094117164612)])
2 stage (test) OrderedDict([('Precision@1', 0.9724103212356567), ('Precision@5', 0.9492348432540894), ('Precision@10', 0.9386504292488098), ('Recall@1', 0.007493751123547554), ('Recall@5', 0.02462335117161274), ('Recall@10', 0.04261213541030884), ('NDCG@5', 0.9548900723457336), ('NDCG@10', 0.9471115469932556), ('NDCG@20', 0.9297663569450378)])
---epoch 9---
1 stage 0.8354305028915405
2 stage 0.9394040107727051
1 stage (val) OrderedDict([('Precision@1', 0.7673215866088867), ('Precision@5', 0.7746638655662537), ('Precision@10', 0.7714064717292786), ('Recall@1', 0.002514956519007683), ('Recall@5', 0.013873514719307423), ('Recall@10', 0.028209399431943893), ('NDCG@5', 0.7729977965354919), ('NDCG@10', 0.7715849876403809), ('NDCG@20', 0.7738854885101318)])
2 stage (val) OrderedDict([('Precision@1', 0.9312306046485901), ('Precision@5', 0.8884176015853882), ('Precision@10', 0.8655635714530945), ('Recall@1', 0.01376135554164648), ('Recall@5', 0.03537794202566147), ('Recall@10', 0.0541960671544075), ('NDCG@5', 0.8995826244354248), ('NDCG@10', 0.8838357329368591), ('NDCG@20', 0.8522109389305115)])
1 stage (test) OrderedDict([('Precision@1', 0.8934035897254944), ('Precision@5', 0.8860311508178711), ('Precision@10', 0.8833188414573669), ('Recall@1', 0.003013131907209754), ('Recall@5', 0.015086286701261997), ('Recall@10', 0.030243460088968277), ('NDCG@5', 0.8870205283164978), ('NDCG@10', 0.884793758392334), ('NDCG@20', 0.8854647278785706)])
2 stage (test) OrderedDict([('Precision@1', 0.9714070558547974), ('Precision@5', 0.9491847157478333), ('Precision@10', 0.9384247064590454), ('Recall@1', 0.00736744562163949), ('Recall@5', 0.024592455476522446), ('Recall@10', 0.04252685606479645), ('NDCG@5', 0.9546566605567932), ('NDCG@10', 0.9468095302581787), ('NDCG@20', 0.9292915463447571)])
---epoch 10---
1 stage 0.8355960249900818
2 stage 0.9394040107727051
1 stage (val) OrderedDict([('Precision@1', 0.7688727974891663), ('Precision@5', 0.7746636271476746), ('Precision@10', 0.7718202471733093), ('Recall@1', 0.002541040303185582), ('Recall@5', 0.013961467891931534), ('Recall@10', 0.028273116797208786), ('NDCG@5', 0.773000955581665), ('NDCG@10', 0.7718640565872192), ('NDCG@20', 0.7727073431015015)])
2 stage (val) OrderedDict([('Precision@1', 0.9317476749420166), ('Precision@5', 0.8890381455421448), ('Precision@10', 0.8653049468994141), ('Recall@1', 0.01362884882837534), ('Recall@5', 0.03554679825901985), ('Recall@10', 0.05419434234499931), ('NDCG@5', 0.8999322652816772), ('NDCG@10', 0.8837052583694458), ('NDCG@20', 0.8518556952476501)])
1 stage (test) OrderedDict([('Precision@1', 0.8929019570350647), ('Precision@5', 0.885880708694458), ('Precision@10', 0.8833938837051392), ('Recall@1', 0.0030145926866680384), ('Recall@5', 0.015049256384372711), ('Recall@10', 0.030261926352977753), ('NDCG@5', 0.886837363243103), ('NDCG@10', 0.8847842812538147), ('NDCG@20', 0.8847416043281555)])
2 stage (test) OrderedDict([('Precision@1', 0.9711562395095825), ('Precision@5', 0.9493853449821472), ('Precision@10', 0.9385502338409424), ('Recall@1', 0.007329127751290798), ('Recall@5', 0.024591417983174324), ('Recall@10', 0.04251473769545555), ('NDCG@5', 0.9547252058982849), ('NDCG@10', 0.9468643069267273), ('NDCG@20', 0.9290348887443542)])
---epoch 11---
1 stage 0.8360927104949951
2 stage 0.939238429069519
1 stage (val) OrderedDict([('Precision@1', 0.7683557271957397), ('Precision@5', 0.7737331390380859), ('Precision@10', 0.7717682719230652), ('Recall@1', 0.0025312236975878477), ('Recall@5', 0.013931073248386383), ('Recall@10', 0.02832760475575924), ('NDCG@5', 0.7724922895431519), ('NDCG@10', 0.7718040943145752), ('NDCG@20', 0.772021472454071)])
2 stage (val) OrderedDict([('Precision@1', 0.9317476749420166), ('Precision@5', 0.8892449736595154), ('Precision@10', 0.8650463223457336), ('Recall@1', 0.01362884882837534), ('Recall@5', 0.03550821170210838), ('Recall@10', 0.05416148528456688), ('NDCG@5', 0.900104284286499), ('NDCG@10', 0.8835929036140442), ('NDCG@20', 0.8518088459968567)])
1 stage (test) OrderedDict([('Precision@1', 0.8939051628112793), ('Precision@5', 0.8860311508178711), ('Precision@10', 0.8833186030387878), ('Recall@1', 0.003023584373295307), ('Recall@5', 0.015060748904943466), ('Recall@10', 0.030287474393844604), ('NDCG@5', 0.8870344161987305), ('NDCG@10', 0.8848085403442383), ('NDCG@20', 0.8842055201530457)])
2 stage (test) OrderedDict([('Precision@1', 0.9709054231643677), ('Precision@5', 0.949586033821106), ('Precision@10', 0.9382994174957275), ('Recall@1', 0.0073165870271623135), ('Recall@5', 0.02467793971300125), ('Recall@10', 0.04251282662153244), ('NDCG@5', 0.9548905491828918), ('NDCG@10', 0.9467594623565674), ('NDCG@20', 0.9288665652275085)])
---epoch 12---
1 stage 0.8357616066932678
2 stage 0.9385761618614197
1 stage (val) OrderedDict([('Precision@1', 0.7688727974891663), ('Precision@5', 0.7738365530967712), ('Precision@10', 0.7722853422164917), ('Recall@1', 0.002532533835619688), ('Recall@5', 0.01389816403388977), ('Recall@10', 0.028517844155430794), ('NDCG@5', 0.7726624011993408), ('NDCG@10', 0.7721626162528992), ('NDCG@20', 0.7710179686546326)])
2 stage (val) OrderedDict([('Precision@1', 0.9312306046485901), ('Precision@5', 0.8893485069274902), ('Precision@10', 0.8653566241264343), ('Recall@1', 0.013489209115505219), ('Recall@5', 0.03542754054069519), ('Recall@10', 0.05412572622299194), ('NDCG@5', 0.9000642895698547), ('NDCG@10', 0.8836491703987122), ('NDCG@20', 0.8513972163200378)])
1 stage (test) OrderedDict([('Precision@1', 0.8931527733802795), ('Precision@5', 0.8858307003974915), ('Precision@10', 0.8832935690879822), ('Recall@1', 0.003019903087988496), ('Recall@5', 0.015028968453407288), ('Recall@10', 0.030282167717814445), ('NDCG@5', 0.886849045753479), ('NDCG@10', 0.8847469687461853), ('NDCG@20', 0.8832032680511475)])
2 stage (test) OrderedDict([('Precision@1', 0.9701529741287231), ('Precision@5', 0.9494857788085938), ('Precision@10', 0.9382240772247314), ('Recall@1', 0.007107383105903864), ('Recall@5', 0.024620026350021362), ('Recall@10', 0.04242135211825371), ('NDCG@5', 0.9546770453453064), ('NDCG@10', 0.9465430974960327), ('NDCG@20', 0.9285589456558228)])
---epoch 13---
1 stage 0.8362582921981812
2 stage 0.937748372554779
1 stage (val) OrderedDict([('Precision@1', 0.7709410786628723), ('Precision@5', 0.7739399671554565), ('Precision@10', 0.7720269560813904), ('Recall@1', 0.0025725809391587973), ('Recall@5', 0.013920770026743412), ('Recall@10', 0.028608163818717003), ('NDCG@5', 0.77310711145401), ('NDCG@10', 0.7722665071487427), ('NDCG@20', 0.7705519795417786)])
2 stage (val) OrderedDict([('Precision@1', 0.9296793937683105), ('Precision@5', 0.8893485069274902), ('Precision@10', 0.8653048276901245), ('Recall@1', 0.013278146274387836), ('Recall@5', 0.03546256572008133), ('Recall@10', 0.054152097553014755), ('NDCG@5', 0.8998194932937622), ('NDCG@10', 0.8834453225135803), ('NDCG@20', 0.8509284257888794)])
1 stage (test) OrderedDict([('Precision@1', 0.8929019570350647), ('Precision@5', 0.8853790760040283), ('Precision@10', 0.8832684755325317), ('Recall@1', 0.0030202772468328476), ('Recall@5', 0.015010001137852669), ('Recall@10', 0.030313612893223763), ('NDCG@5', 0.8865785002708435), ('NDCG@10', 0.8847340941429138), ('NDCG@20', 0.8825224041938782)])
2 stage (test) OrderedDict([('Precision@1', 0.9696513414382935), ('Precision@5', 0.9491847157478333), ('Precision@10', 0.9378730654716492), ('Recall@1', 0.007022056728601456), ('Recall@5', 0.024508126080036163), ('Recall@10', 0.0422583743929863), ('NDCG@5', 0.9543265700340271), ('NDCG@10', 0.9461403489112854), ('NDCG@20', 0.9281177520751953)])
---epoch 14---
1 stage 0.8352649211883545
2 stage 0.937748372554779
1 stage (val) OrderedDict([('Precision@1', 0.7699069380760193), ('Precision@5', 0.7740433216094971), ('Precision@10', 0.7716652154922485), ('Recall@1', 0.0025672533083707094), ('Recall@5', 0.013907205313444138), ('Recall@10', 0.028547627851366997), ('NDCG@5', 0.7730266451835632), ('NDCG@10', 0.7719103097915649), ('NDCG@20', 0.7697432637214661)])
2 stage (val) OrderedDict([('Precision@1', 0.9286453127861023), ('Precision@5', 0.8890382647514343), ('Precision@10', 0.8652012944221497), ('Recall@1', 0.012992068193852901), ('Recall@5', 0.03513404354453087), ('Recall@10', 0.05389030650258064), ('NDCG@5', 0.8995911478996277), ('NDCG@10', 0.8832032084465027), ('NDCG@20', 0.8503848314285278)])
1 stage (test) OrderedDict([('Precision@1', 0.8918986916542053), ('Precision@5', 0.8851282596588135), ('Precision@10', 0.8826413154602051), ('Recall@1', 0.003015252063050866), ('Recall@5', 0.015008385293185711), ('Recall@10', 0.030310826376080513), ('NDCG@5', 0.8863389492034912), ('NDCG@10', 0.8842962384223938), ('NDCG@20', 0.8815882802009583)])
2 stage (test) OrderedDict([('Precision@1', 0.9701529741287231), ('Precision@5', 0.9492349028587341), ('Precision@10', 0.9378228783607483), ('Recall@1', 0.006919103674590588), ('Recall@5', 0.024281606078147888), ('Recall@10', 0.04200495406985283), ('NDCG@5', 0.9543092846870422), ('NDCG@10', 0.9460137486457825), ('NDCG@20', 0.9279131889343262)])
---epoch 15---
1 stage 0.8342715501785278
2 stage 0.9382450580596924
1 stage (val) OrderedDict([('Precision@1', 0.7688727974891663), ('Precision@5', 0.7741466164588928), ('Precision@10', 0.7720270752906799), ('Recall@1', 0.0025641624815762043), ('Recall@5', 0.013893092051148415), ('Recall@10', 0.028611961752176285), ('NDCG@5', 0.7728968262672424), ('NDCG@10', 0.7719386219978333), ('NDCG@20', 0.7689577341079712)])
2 stage (val) OrderedDict([('Precision@1', 0.9301964640617371), ('Precision@5', 0.8893486857414246), ('Precision@10', 0.8648910522460938), ('Recall@1', 0.012969724833965302), ('Recall@5', 0.035027049481868744), ('Recall@10', 0.0536569282412529), ('NDCG@5', 0.8999989032745361), ('NDCG@10', 0.8830557465553284), ('NDCG@20', 0.8499181270599365)])
1 stage (test) OrderedDict([('Precision@1', 0.890895426273346), ('Precision@5', 0.885128378868103), ('Precision@10', 0.8826413154602051), ('Recall@1', 0.0030101861339062452), ('Recall@5', 0.014987428672611713), ('Recall@10', 0.030330732464790344), ('NDCG@5', 0.8862097859382629), ('NDCG@10', 0.8841873407363892), ('NDCG@20', 0.8804370760917664)])
2 stage (test) OrderedDict([('Precision@1', 0.9701529741287231), ('Precision@5', 0.9493853449821472), ('Precision@10', 0.9377727508544922), ('Recall@1', 0.006855727173388004), ('Recall@5', 0.0241836067289114), ('Recall@10', 0.04189404845237732), ('NDCG@5', 0.9543688297271729), ('NDCG@10', 0.9459018707275391), ('NDCG@20', 0.9275990128517151)])
---epoch 16---
1 stage 0.8342715501785278
2 stage 0.939238429069519
1 stage (val) OrderedDict([('Precision@1', 0.7699069380760193), ('Precision@5', 0.7728022933006287), ('Precision@10', 0.7725440263748169), ('Recall@1', 0.002587593160569668), ('Recall@5', 0.01379038393497467), ('Recall@10', 0.02874572202563286), ('NDCG@5', 0.7722145915031433), ('NDCG@10', 0.7722715139389038), ('NDCG@20', 0.7682958245277405)])
2 stage (val) OrderedDict([('Precision@1', 0.9322647452354431), ('Precision@5', 0.8898658752441406), ('Precision@10', 0.8647358417510986), ('Recall@1', 0.01313948817551136), ('Recall@5', 0.03528093919157982), ('Recall@10', 0.053822338581085205), ('NDCG@5', 0.9008933305740356), ('NDCG@10', 0.8833855390548706), ('NDCG@20', 0.8500461578369141)])
1 stage (test) OrderedDict([('Precision@1', 0.8903937935829163), ('Precision@5', 0.8849779367446899), ('Precision@10', 0.8826160430908203), ('Recall@1', 0.0030037390533834696), ('Recall@5', 0.014980170875787735), ('Recall@10', 0.03035588748753071), ('NDCG@5', 0.8861247897148132), ('NDCG@10', 0.8841509819030762), ('NDCG@20', 0.8795574903488159)])
2 stage (test) OrderedDict([('Precision@1', 0.9706546068191528), ('Precision@5', 0.9496863484382629), ('Precision@10', 0.9377727508544922), ('Recall@1', 0.006977349519729614), ('Recall@5', 0.024342140182852745), ('Recall@10', 0.0420152023434639), ('NDCG@5', 0.9546906352043152), ('NDCG@10', 0.9460608959197998), ('NDCG@20', 0.9275720715522766)])
---epoch 17---
1 stage 0.8339403867721558
2 stage 0.9387417435646057
1 stage (val) OrderedDict([('Precision@1', 0.7688727974891663), ('Precision@5', 0.7732157707214355), ('Precision@10', 0.7721819281578064), ('Recall@1', 0.0025853486731648445), ('Recall@5', 0.013845651410520077), ('Recall@10', 0.028628256171941757), ('NDCG@5', 0.7722429633140564), ('NDCG@10', 0.7718576192855835), ('NDCG@20', 0.767686665058136)])
2 stage (val) OrderedDict([('Precision@1', 0.9312306046485901), ('Precision@5', 0.8897624015808105), ('Precision@10', 0.864425778388977), ('Recall@1', 0.013013952411711216), ('Recall@5', 0.035018254071474075), ('Recall@10', 0.05353393778204918), ('NDCG@5', 0.9006258249282837), ('NDCG@10', 0.8829480409622192), ('NDCG@20', 0.8497097492218018)])
1 stage (test) OrderedDict([('Precision@1', 0.8903937935829163), ('Precision@5', 0.8849778175354004), ('Precision@10', 0.8823150396347046), ('Recall@1', 0.0030055013485252857), ('Recall@5', 0.014995132572948933), ('Recall@10', 0.030364053323864937), ('NDCG@5', 0.8861544132232666), ('NDCG@10', 0.8839383721351624), ('NDCG@20', 0.8786713480949402)])
2 stage (test) OrderedDict([('Precision@1', 0.970403790473938), ('Precision@5', 0.9499371647834778), ('Precision@10', 0.9376974701881409), ('Recall@1', 0.006946822162717581), ('Recall@5', 0.024365242570638657), ('Recall@10', 0.04201309755444527), ('NDCG@5', 0.9548534750938416), ('NDCG@10', 0.9460048079490662), ('NDCG@20', 0.9273778200149536)])
---epoch 18---
1 stage 0.8332781791687012
2 stage 0.9384106397628784
1 stage (val) OrderedDict([('Precision@1', 0.7683557271957397), ('Precision@5', 0.7734225392341614), ('Precision@10', 0.7714068293571472), ('Recall@1', 0.0025810804218053818), ('Recall@5', 0.013942412100732327), ('Recall@10', 0.02855859324336052), ('NDCG@5', 0.7723387479782104), ('NDCG@10', 0.7713356614112854), ('NDCG@20', 0.767216145992279)])
2 stage (val) OrderedDict([('Precision@1', 0.9301964640617371), ('Precision@5', 0.889141857624054), ('Precision@10', 0.8643222451210022), ('Recall@1', 0.01304357685148716), ('Recall@5', 0.034884266555309296), ('Recall@10', 0.053411055356264114), ('NDCG@5', 0.9000440239906311), ('NDCG@10', 0.8827939033508301), ('NDCG@20', 0.8493320941925049)])
1 stage (test) OrderedDict([('Precision@1', 0.8896413445472717), ('Precision@5', 0.8848274350166321), ('Precision@10', 0.8820641040802002), ('Recall@1', 0.0029982151463627815), ('Recall@5', 0.015002566389739513), ('Recall@10', 0.03033376671373844), ('NDCG@5', 0.8859151601791382), ('NDCG@10', 0.883664071559906), ('NDCG@20', 0.8777598142623901)])
2 stage (test) OrderedDict([('Precision@1', 0.970403790473938), ('Precision@5', 0.949836790561676), ('Precision@10', 0.9376975297927856), ('Recall@1', 0.006989224348217249), ('Recall@5', 0.024378934875130653), ('Recall@10', 0.042055051773786545), ('NDCG@5', 0.9547573924064636), ('NDCG@10', 0.946016252040863), ('NDCG@20', 0.9272603392601013)])
---epoch 19---
1 stage 0.8344370722770691
2 stage 0.937417209148407
1 stage (val) OrderedDict([('Precision@1', 0.7724922299385071), ('Precision@5', 0.7734224796295166), ('Precision@10', 0.7712516784667969), ('Recall@1', 0.0026290433015674353), ('Recall@5', 0.013950219377875328), ('Recall@10', 0.028614703565835953), ('NDCG@5', 0.7729315161705017), ('NDCG@10', 0.7716161608695984), ('NDCG@20', 0.7671958208084106)])
2 stage (val) OrderedDict([('Precision@1', 0.9281282424926758), ('Precision@5', 0.8886248469352722), ('Precision@10', 0.8638569712638855), ('Recall@1', 0.01269526220858097), ('Recall@5', 0.03449012339115143), ('Recall@10', 0.05301033332943916), ('NDCG@5', 0.8991745114326477), ('NDCG@10', 0.8819308876991272), ('NDCG@20', 0.8481321334838867)])
1 stage (test) OrderedDict([('Precision@1', 0.8893905282020569), ('Precision@5', 0.8843759894371033), ('Precision@10', 0.8814369440078735), ('Recall@1', 0.0030090906657278538), ('Recall@5', 0.014989291317760944), ('Recall@10', 0.030279411002993584), ('NDCG@5', 0.8855165243148804), ('NDCG@10', 0.8831601738929749), ('NDCG@20', 0.8770545125007629)])
2 stage (test) OrderedDict([('Precision@1', 0.9699021577835083), ('Precision@5', 0.9497867226600647), ('Precision@10', 0.9374968409538269), ('Recall@1', 0.006969019304960966), ('Recall@5', 0.024379633367061615), ('Recall@10', 0.04202014207839966), ('NDCG@5', 0.9546409845352173), ('NDCG@10', 0.9458295702934265), ('NDCG@20', 0.9270906448364258)])
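The logs above report standard top-k ranking metrics for each stage. The notebook's own metric code is not shown in this excerpt. Below is a minimal sketch of how Precision@k, Recall@k and NDCG@k can be computed for a single user. It assumes binary relevance and a log2 position discount, and the helper names are hypothetical, so the notebook's implementation may differ in details such as NDCG normalization.

```python
import math
from typing import Sequence, Set


def precision_at_k(ranked: Sequence[int], relevant: Set[int], k: int) -> float:
    # Fraction of the top-k recommended items that are relevant
    return sum(1 for item in ranked[:k] if item in relevant) / k


def recall_at_k(ranked: Sequence[int], relevant: Set[int], k: int) -> float:
    # Fraction of all relevant items that appear in the top k
    if not relevant:
        return 0.0
    return sum(1 for item in ranked[:k] if item in relevant) / len(relevant)


def ndcg_at_k(ranked: Sequence[int], relevant: Set[int], k: int) -> float:
    # DCG with binary gains; position i (0-based) is discounted by log2(i + 2)
    dcg = sum(1.0 / math.log2(i + 2)
              for i, item in enumerate(ranked[:k]) if item in relevant)
    # Ideal DCG: as many relevant items as fit in the top k, ranked first
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(i + 2) for i in range(ideal_hits))
    return dcg / idcg if idcg > 0 else 0.0


# Toy example (illustrative data, not from the experiments)
ranked = [3, 1, 4, 5, 9]
relevant = {1, 5, 7, 8}
print(precision_at_k(ranked, relevant, 5))  # 0.4
print(recall_at_k(ranked, relevant, 5))     # 0.5
print(ndcg_at_k(ranked, relevant, 5))       # ≈ 0.414
```

Precision@k divides the hits by k, while Recall@k divides them by the number of relevant items. A high Precision@1 combined with a very small Recall@1, as in the logs (about 0.89 vs 0.003 for the one-stage test set), therefore suggests that users have many relevant items each.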
# Report the validation and test metrics of both stages at the best validation epoch
print("Best validation epoch: {}".format(best_epoch))
print("Best validation stage results\n 1 stage: {}\n 2 stage: {}".format(
    val_results[best_epoch][0], val_results[best_epoch][1]))
print("Best test results\n 1 stage: {}\n 2 stage: {}".format(
    test_results[best_epoch][0], test_results[best_epoch][1]))
Best validation epoch: 6
Best validation stage results
 1 stage: OrderedDict([('Precision@1', 0.7652533650398254), ('Precision@5', 0.7777659296989441), ('Precision@10', 0.7713546752929688), ('Recall@1', 0.002462433883920312), ('Recall@5', 0.014112057164311409), ('Recall@10', 0.02813573367893696), ('NDCG@5', 0.7751555442810059), ('NDCG@10', 0.7717844247817993), ('NDCG@20', 0.7759582996368408)])
 2 stage: OrderedDict([('Precision@1', 0.9353671073913574), ('Precision@5', 0.8887278437614441), ('Precision@10', 0.866390585899353), ('Recall@1', 0.013995282351970673), ('Recall@5', 0.035590916872024536), ('Recall@10', 0.05474210903048515), ('NDCG@5', 0.9002852439880371), ('NDCG@10', 0.8846437931060791), ('NDCG@20', 0.8533338904380798)])
Best test results
 1 stage: OrderedDict([('Precision@1', 0.8931527733802795), ('Precision@5', 0.8867837190628052), ('Precision@10', 0.8830179572105408), ('Recall@1', 0.0030013115610927343), ('Recall@5', 0.01514824852347374), ('Recall@10', 0.03022931143641472), ('NDCG@5', 0.8876431584358215), ('NDCG@10', 0.8846896886825562), ('NDCG@20', 0.886621356010437)])
 2 stage: OrderedDict([('Precision@1', 0.9729119539260864), ('Precision@5', 0.9496863484382629), ('Precision@10', 0.9388262033462524), ('Recall@1', 0.007614111993461847), ('Recall@5', 0.024715635925531387), ('Recall@10', 0.04267028719186783), ('NDCG@5', 0.9553412199020386), ('NDCG@10', 0.9473593235015869), ('NDCG@20', 0.9304744601249695)])
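This excerpt does not show how `best_epoch` is chosen. Below is a minimal sketch of validation-based model selection. It assumes that `val_results[epoch]` is a `(one_stage, two_stage)` pair of metric dicts, as the print statements above suggest. It also assumes the selection criterion is a single two-stage validation metric such as `NDCG@10`; the notebook's actual criterion is not stated here. The helper `select_best_epoch` is hypothetical.

```python
from collections import OrderedDict


def select_best_epoch(val_results, stage=1, metric="NDCG@10"):
    """Return the epoch index with the highest validation `metric`.

    Assumes val_results[epoch] is a (one_stage, two_stage) pair of
    metric dicts; stage=1 selects the two-stage metrics.
    """
    scores = [epoch_metrics[stage][metric] for epoch_metrics in val_results]
    # max() keeps the first maximal element, so ties go to the earliest epoch
    return max(range(len(scores)), key=scores.__getitem__)


# Toy example (illustrative numbers, not from the experiments)
toy_val = [
    (OrderedDict([("NDCG@10", 0.70)]), OrderedDict([("NDCG@10", 0.80)])),
    (OrderedDict([("NDCG@10", 0.72)]), OrderedDict([("NDCG@10", 0.85)])),
    (OrderedDict([("NDCG@10", 0.75)]), OrderedDict([("NDCG@10", 0.83)])),
]
print(select_best_epoch(toy_val))           # 1 (best two-stage score)
print(select_best_epoch(toy_val, stage=0))  # 2 (best one-stage score)
```

Selecting on the two-stage metric matches the goal of optimizing the whole system rather than the candidate generator alone, but this is a design choice, not something the excerpt confirms.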
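# Save the best epoch and all per-epoch metrics; the file name encodes the
# loss-type suffix, alpha and random seed, e.g. "ce-a0.001_0.pkl"
with open("{}-a{}_{}.pkl".format(
        args.loss_type.split("_")[1], args.alpha, args.seed), "wb") as f:
    pickle.dump((best_epoch, val_results, test_results), f)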
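You can reload the saved tuple later, for example to compare runs that use different loss types, alpha values or seeds. Below is a minimal sketch. The helper `load_run` and its `directory` parameter are hypothetical, and the file name follows the `{loss}-a{alpha}_{seed}.pkl` pattern used above (e.g. `ce-a0.001_0.pkl`).

```python
import os
import pickle


def load_run(loss, alpha, seed, directory="."):
    # Rebuild the file name with the same pattern as the pickle.dump call
    path = os.path.join(directory, "{}-a{}_{}.pkl".format(loss, alpha, seed))
    with open(path, "rb") as f:
        best_epoch, val_results, test_results = pickle.load(f)
    # Return the best epoch and its (one-stage, two-stage) test metrics
    return best_epoch, test_results[best_epoch]


# Usage (assumes the file written above exists in the working directory):
# best_epoch, (one_stage_test, two_stage_test) = load_run("ce", 0.001, 0)
```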
!apt-get install tree
!tree --du -h
.
├── [ 20K]  ce-a0.001_0.pkl
├── [651M]  ml-1m
│   ├── [626M]  full_impression_feats.pt
│   ├── [463K]  logging_policy.pt
│   ├── [167K]  movies.dat
│   ├── [463K]  ranker.pt
│   ├── [ 23M]  ratings.dat
│   ├── [5.4K]  README
│   ├── [502K]  simulator.pt
│   └── [131K]  users.dat
├── [5.6M]  ml-1m.zip
└── [ 54M]  sample_data
    ├── [1.7K]  anscombe.json
    ├── [294K]  california_housing_test.csv
    ├── [1.6M]  california_housing_train.csv
    ├── [ 17M]  mnist_test.csv
    ├── [ 35M]  mnist_train_small.csv
    └── [ 930]  README.md

 711M used in 2 directories, 16 files

Experimental results

Conclusion

  • In both the one-stage and the two-stage evaluations, the 2-IPS method outperforms both 1-IPS and the cross-entropy baseline.

  • 1-IPS outperforms the cross-entropy method in the one-stage evaluation but underperforms it in the two-stage evaluation. Improving one stage of the system in isolation therefore does not necessarily improve the performance of the system as a whole.
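The difference between the one-stage and the two-stage evaluation can be made concrete with a small sketch. It assumes that the one-stage evaluation scores the candidate generator's own top-k. It also assumes that the two-stage evaluation first retrieves the candidate generator's top `n_candidates` items and then re-orders only those items by the ranking model's scores. The function names, `n_candidates` and the toy score arrays are hypothetical, and the notebook's exact protocol may differ.

```python
import numpy as np


def top_k(scores, k):
    # Indices of the k highest scores, in descending score order
    return np.argsort(-scores)[:k]


def one_stage_topk(cand_scores, k):
    # Evaluate the candidate generator on its own
    return top_k(cand_scores, k)


def two_stage_topk(cand_scores, rank_scores, n_candidates, k):
    # Stage 1: the candidate generator retrieves n_candidates items
    candidates = top_k(cand_scores, n_candidates)
    # Stage 2: the ranker re-orders only the retrieved candidates
    order = np.argsort(-rank_scores[candidates])[:k]
    return candidates[order]


def precision(items, relevant):
    # Fraction of the recommended items that are relevant
    return float(np.isin(items, relevant).mean())


# Toy example with 6 items (illustrative scores, not from the experiments)
cand_scores = np.array([0.9, 0.8, 0.7, 0.6, 0.1, 0.05])
rank_scores = np.array([0.1, 0.2, 0.9, 0.8, 0.95, 0.99])
relevant = [2, 3, 4, 5]
print(one_stage_topk(cand_scores, 2))                  # [0 1]
print(two_stage_topk(cand_scores, rank_scores, 4, 2))  # [2 3]
```

The ranker scores items 4 and 5 highest, but they never reach it because the candidate generator does not retrieve them. The final list therefore depends on both stages, which is why a candidate generator tuned only for its own one-stage metric may not be the best choice for the full system.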