Top-K Off-Policy Correction for a REINFORCE Recommender System

CLI run

!gdown --id 1erBjYEOa7IuOIGpI8pGPn1WNBAC4Rv0-
!git clone https://github.com/massquantity/DBRL.git
!unzip /content/ECommAI_EUIR_round2_train_20190821.zip
!mv ECommAI_EUIR_round2_train_20190816/*.csv DBRL/dbrl/resources
%cd DBRL/dbrl
Downloading...
From: https://drive.google.com/uc?id=1erBjYEOa7IuOIGpI8pGPn1WNBAC4Rv0-
To: /content/ECommAI_EUIR_round2_train_20190821.zip
100% 894M/894M [00:06<00:00, 146MB/s]
Cloning into 'DBRL'...
remote: Enumerating objects: 118, done.
remote: Counting objects: 100% (118/118), done.
remote: Compressing objects: 100% (83/83), done.
remote: Total 118 (delta 29), reused 114 (delta 25), pack-reused 0
Receiving objects: 100% (118/118), 203.89 KiB | 2.87 MiB/s, done.
Resolving deltas: 100% (29/29), done.
Archive:  /content/ECommAI_EUIR_round2_train_20190821.zip
   creating: ECommAI_EUIR_round2_train_20190816/
  inflating: ECommAI_EUIR_round2_train_20190816/user_behavior.csv  
  inflating: ECommAI_EUIR_round2_train_20190816/item.csv  
  inflating: ECommAI_EUIR_round2_train_20190816/user.csv  
/content/DBRL/dbrl
!python run_prepare_data.py
{'seed': 0}
n_users: 80000, n_items: 1047166, behavior length: 3234367
prepare data done!, time elapsed: 173.06
!python run_pretrain_embeddings.py --lr 0.001 --n_epochs 4
A list of all args: 
======================
{'batch_size': 2048,
 'data': 'tianchi.csv',
 'embed_size': 32,
 'loss': 'cosine',
 'lr': 0.001,
 'n_epochs': 4,
 'neg_item': 1,
 'seed': 0}

n_users: 80000, n_items: 912114, train_shape: (2587518, 10), eval_shape: (646849, 10)
100% 2527/2527 [02:22<00:00, 17.73it/s]
100% 632/632 [00:11<00:00, 56.31it/s]
epoch 1, train_loss: 0.3253, eval loss: 0.3537, eval roc: 0.7370
100% 2527/2527 [02:21<00:00, 17.85it/s]
100% 632/632 [00:11<00:00, 55.87it/s]
epoch 2, train_loss: 0.2568, eval loss: 0.3351, eval roc: 0.7697
100% 2527/2527 [02:21<00:00, 17.84it/s]
100% 632/632 [00:11<00:00, 56.34it/s]
epoch 3, train_loss: 0.2260, eval loss: 0.3309, eval roc: 0.7772
100% 2527/2527 [02:20<00:00, 17.96it/s]
100% 632/632 [00:11<00:00, 57.03it/s]
epoch 4, train_loss: 0.2036, eval loss: 0.3296, eval roc: 0.7829
user_embeds shape: (80000, 32), item_embeds shape: (912115, 32)
pretrain embeddings done!
!python run_reinforce.py --n_epochs 1 --lr 1e-5
A list of all args: 
======================
{'batch_size': 128,
 'data': 'tianchi.csv',
 'gamma': 0.99,
 'hidden_size': 64,
 'hist_num': 10,
 'item_embeds': 'tianchi_item_embeddings.npy',
 'lr': 1e-05,
 'n_epochs': 1,
 'n_rec': 10,
 'seed': 0,
 'sess_mode': 'interval',
 'user_embeds': 'tianchi_user_embeddings.npy',
 'weight_decay': 0.0}

Number of parameters: policy: 118628454,  beta: 59310067
Caution: Will compute loss every 10 step(s)

Epoch 1 start-time: 2021-10-21 10:46:33

train: 100% 19590/19590 [2:24:05<00:00,  2.27it/s]
last_eval: 100% 625/625 [00:47<00:00, 13.12it/s]

policy_loss: 665.2355, beta_loss: 13.6349, importance_weight: 0.8856, lambda_k: 9.9993, 
reward: 455, ndcg_next_item: 0.000999, ndcg_all_item: 0.039790, ndcg: 0.027800

******************** EVAL ********************
eval: 100% 10516/10516 [20:26<00:00,  8.58it/s]
last_eval: 100% 625/625 [00:47<00:00, 13.12it/s]

policy_loss: 1333.3558, beta_loss: 13.5444, importance_weight: 0.9655, lambda_k: 9.9992, 
reward: 290, ndcg_next_item: 0.001165, ndcg_all_item: 0.008780, ndcg: 0.007617
================================================================================
train and save done!
!apt-get -qq install tree
!tree --du -h .
Selecting previously unselected package tree.
(Reading database ... 155047 files and directories currently installed.)
Preparing to unpack .../tree_1.7.0-5_amd64.deb ...
Unpacking tree (1.7.0-5) ...
Setting up tree (1.7.0-5) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
.
├── [ 56K]  data
│   ├── [7.7K]  dataset.py
│   ├── [ 126]  __init__.py
│   ├── [8.4K]  process.py
│   ├── [ 24K]  __pycache__
│   │   ├── [5.5K]  dataset.cpython-37.pyc
│   │   ├── [ 284]  __init__.cpython-37.pyc
│   │   ├── [5.6K]  process.cpython-37.pyc
│   │   ├── [6.3K]  session.cpython-37.pyc
│   │   └── [2.2K]  split.cpython-37.pyc
│   ├── [9.6K]  session.py
│   └── [2.5K]  split.py
├── [ 17K]  evaluate
│   ├── [4.1K]  evaluate.py
│   ├── [  45]  __init__.py
│   ├── [1.6K]  metrics.py
│   └── [7.3K]  __pycache__
│       ├── [1.9K]  evaluate.cpython-37.pyc
│       ├── [ 178]  __init__.cpython-37.pyc
│       └── [1.2K]  metrics.cpython-37.pyc
├── [   0]  __init__.py
├── [ 44K]  models
│   ├── [7.6K]  bcq.py
│   ├── [4.2K]  ddpg.py
│   ├── [3.3K]  dssm.py
│   ├── [ 103]  __init__.py
│   ├── [ 19K]  __pycache__
│   │   ├── [5.1K]  bcq.cpython-37.pyc
│   │   ├── [3.3K]  ddpg.cpython-37.pyc
│   │   ├── [2.4K]  dssm.cpython-37.pyc
│   │   ├── [ 256]  __init__.cpython-37.pyc
│   │   └── [3.9K]  youtube_topk.cpython-37.pyc
│   └── [5.7K]  youtube_topk.py
├── [ 24K]  network
│   ├── [  20]  __init__.py
│   ├── [9.1K]  net.py
│   └── [ 11K]  __pycache__
│       ├── [ 134]  __init__.cpython-37.pyc
│       └── [7.2K]  net.cpython-37.pyc
├── [4.1K]  __pycache__
│   └── [ 106]  __init__.cpython-37.pyc
├── [3.0G]  resources
│   ├── [   0]  aa
│   ├── [114M]  item.csv
│   ├── [ 15M]  item_map.json
│   ├── [574M]  model_reinforce.pt
│   ├── [160M]  tianchi.csv
│   ├── [111M]  tianchi_item_embeddings.npy
│   ├── [9.8M]  tianchi_user_embeddings.npy
│   ├── [2.0G]  user_behavior.csv
│   ├── [ 19M]  user.csv
│   └── [1.2M]  user_map.json
├── [5.9K]  run_bcq.py
├── [5.1K]  run_ddpg.py
├── [2.5K]  run_prepare_data.py
├── [4.4K]  run_pretrain_embeddings.py
├── [5.1K]  run_reinforce.py
├── [9.9K]  serialization
│   ├── [  43]  __init__.py
│   ├── [5.0K]  __pycache__
│   │   ├── [ 182]  __init__.cpython-37.pyc
│   │   └── [ 845]  serialize.cpython-37.pyc
│   └── [ 889]  serialize.py
├── [ 18K]  trainer
│   ├── [  68]  __init__.py
│   ├── [1.9K]  pretrain.py
│   ├── [8.2K]  __pycache__
│   │   ├── [ 202]  __init__.cpython-37.pyc
│   │   ├── [1.5K]  pretrain.cpython-37.pyc
│   │   └── [2.5K]  train.cpython-37.pyc
│   └── [3.9K]  train.py
└── [ 21K]  utils
    ├── [1.7K]  info.py
    ├── [ 156]  __init__.py
    ├── [2.2K]  misc.py
    ├── [1.7K]  params.py
    ├── [ 10K]  __pycache__
    │   ├── [1.5K]  info.cpython-37.pyc
    │   ├── [ 325]  __init__.cpython-37.pyc
    │   ├── [2.0K]  misc.cpython-37.pyc
    │   ├── [1.5K]  params.cpython-37.pyc
    │   └── [ 920]  sampling.cpython-37.pyc
    └── [ 908]  sampling.py

 3.0G used in 16 directories, 67 files

Code analysis

Data preparation

import os
import sys
sys.path.append(os.pardir)
import warnings
warnings.filterwarnings("ignore")
import argparse
import time
import numpy as np
import pandas as pd


def parse_args():
    parser = argparse.ArgumentParser(description="run_prepare_data")
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args(args=[])


def bucket_age(age):
    if age < 30:
        return 1
    elif age < 40:
        return 2
    elif age < 50:
        return 3
    else:
        return 4


if __name__ == "__main__":
    args = parse_args()
    print(vars(args))
    np.random.seed(args.seed)
    start_time = time.perf_counter()

    # 1. loading the data into memory

    user_feat = pd.read_csv("resources/user.csv", header=None,
                            names=["user", "sex", "age", "pur_power"])
    item_feat = pd.read_csv("resources/item.csv", header=None,
                            names=["item", "category", "shop", "brand"])
    behavior = pd.read_csv("resources/user_behavior.csv", header=None,
                           names=["user", "item", "behavior", "time"])
    
    # 2. sorting values chronologically and dropping duplicate records

    behavior = behavior.sort_values(by="time").reset_index(drop=True)
    behavior = behavior.drop_duplicates(subset=["user", "item", "behavior"])

    # 3. Choosing 60K random users with short journeys and 20K with long journeys
    user_counts = behavior.groupby("user")[["user"]].count().rename(
        columns={"user": "count_user"}
    ).sort_values("count_user", ascending=False)

    short_users = np.array(
        user_counts[
            (user_counts.count_user > 5) & (user_counts.count_user <= 50)
        ].index
    )
    long_users = np.array(
        user_counts[
            (user_counts.count_user > 50) & (user_counts.count_user <= 200)
        ].index
    )
    short_chosen_users = np.random.choice(short_users, 60000, replace=False)
    long_chosen_users = np.random.choice(long_users, 20000, replace=False)
    chosen_users = np.concatenate([short_chosen_users, long_chosen_users])

    behavior = behavior[behavior.user.isin(chosen_users)]
    print(f"n_users: {behavior.user.nunique()}, "
          f"n_items: {behavior.item.nunique()}, "
          f"behavior length: {len(behavior)}")

    # 4. Merging user and item features, bucketizing the age and saving the processed data
    behavior = behavior.merge(user_feat, on="user")
    behavior = behavior.merge(item_feat, on="item")
    behavior["age"] = behavior["age"].apply(bucket_age)
    behavior = behavior.sort_values(by="time").reset_index(drop=True)
    behavior.to_csv("resources/tianchi.csv", header=None, index=False)
    print(f"prepare data done!, "
          f"time elapsed: {(time.perf_counter() - start_time):.2f}")

Embeddings

import os
import sys
sys.path.append(os.pardir)
import warnings
warnings.filterwarnings("ignore")
import argparse
from pprint import pprint
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from dbrl.data import process_feat_data, FeatDataset
from dbrl.models import DSSM
from dbrl.utils import sample_items_random, init_param_dssm, generate_embeddings
from dbrl.trainer import pretrain_model
from dbrl.serialization import save_npy, save_json


def parse_args():
    parser = argparse.ArgumentParser(description="run_pretrain_embeddings")
    parser.add_argument("--data", type=str, default="tianchi.csv")
    parser.add_argument("--n_epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=2048)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--embed_size", type=int, default=32)
    parser.add_argument("--loss", type=str, default="cosine",
                        help="cosine or bce loss")
    parser.add_argument("--neg_item", type=int, default=1)
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    print("A list all args: \n======================")
    pprint(vars(args))
    print()

    # 1. Setting arguments/params

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    PATH = os.path.join("resources", args.data)
    EMBEDDING_PATH = "resources/"
    static_feat = ["sex", "age", "pur_power"]
    dynamic_feat = ["category", "shop", "brand"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_epochs = args.n_epochs
    batch_size = args.batch_size
    lr = args.lr
    item_embed_size = args.embed_size
    feat_embed_size = args.embed_size
    hidden_size = (256, 128)
    criterion = (
        nn.CosineEmbeddingLoss()
        if args.loss == "cosine"
        else nn.BCEWithLogitsLoss()
    )
    criterion_type = (
        "cosine"
        if "cosine" in criterion.__class__.__name__.lower()
        else "bce"
    )
    neg_label = -1. if criterion_type == "cosine" else 0.
    neg_item = args.neg_item

    # 2. Preprocessing

    columns = ["user", "item", "label", "time", "sex", "age", "pur_power",
               "category", "shop", "brand"]

    (
        n_users,
        n_items,
        train_user_consumed,
        eval_user_consumed,
        train_data,
        eval_data,
        user_map,
        item_map,
        feat_map
    ) = process_feat_data(
        PATH, columns, test_size=0.2, time_col="time",
        static_feat=static_feat, dynamic_feat=dynamic_feat
    )
    print(f"n_users: {n_users}, n_items: {n_items}, "
          f"train_shape: {train_data.shape}, eval_shape: {eval_data.shape}")
    
    # 3. Random negative sampling

    train_user, train_item, train_label = sample_items_random(
        train_data, n_items, train_user_consumed, neg_label, neg_item
    )
    eval_user, eval_item, eval_label = sample_items_random(
        eval_data, n_items, eval_user_consumed, neg_label, neg_item
    )

    # 4. Putting data into torch dataset format and dataloader

    train_dataset = FeatDataset(
        train_user,
        train_item,
        train_label,
        feat_map,
        static_feat,
        dynamic_feat
    )
    eval_dataset = FeatDataset(
        eval_user,
        eval_item,
        eval_label,
        feat_map,
        static_feat,
        dynamic_feat
    )
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=0)
    eval_loader = DataLoader(dataset=eval_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=0)

    # 5. DSSM embedding model training

    model = DSSM(
        item_embed_size,
        feat_embed_size,
        n_users,
        n_items,
        hidden_size,
        feat_map,
        static_feat,
        dynamic_feat,
        use_bn=True
    ).to(device)
    init_param_dssm(model)
    optimizer = Adam(model.parameters(), lr=lr)  # weight_decay

    pretrain_model(model, train_loader, eval_loader, n_epochs, criterion,
                   criterion_type, optimizer, device)
    
    # 6. Generate and save embeddings
    
    user_embeddings, item_embeddings = generate_embeddings(
        model, n_users, n_items, feat_map, static_feat, dynamic_feat, device
    )
    print(f"user_embeds shape: {user_embeddings.shape},"
          f" item_embeds shape: {item_embeddings.shape}")

    save_npy(user_embeddings, item_embeddings, EMBEDDING_PATH)
    save_json(
        user_map, item_map, user_embeddings, item_embeddings, EMBEDDING_PATH
    )
    print("pretrain embeddings done!")

REINFORCE model

import torch
import torch.nn as nn
from torch.distributions import Categorical


class Reinforce(nn.Module):
    def __init__(
            self,
            policy,
            policy_optim,
            beta,
            beta_optim,
            hidden_size,
            gamma=0.99,
            k=10,
            weight_clip=2.0,
            offpolicy_correction=True,
            topk=True,
            adaptive_softmax=True,
            cutoffs=None,
            device=torch.device("cpu"),
    ):
        super(Reinforce, self).__init__()
        self.policy = policy
        self.policy_optim = policy_optim
        self.beta = beta
        self.beta_optim = beta_optim
        self.beta_criterion = nn.CrossEntropyLoss()
        self.gamma = gamma
        self.k = k
        self.weight_clip = weight_clip
        self.offpolicy_correction = offpolicy_correction
        self.topk = topk
        self.adaptive_softmax = adaptive_softmax
        if adaptive_softmax:
            assert cutoffs is not None, (
                "must provide cutoffs when using adaptive_softmax"
            )
            self.softmax_loss = nn.AdaptiveLogSoftmaxWithLoss(
                in_features=hidden_size,
                n_classes=policy.item_embeds.weight.size(0),
                cutoffs=cutoffs,
                div_value=4.
            ).to(device)
        self.device = device

    def update(self, data):
        (
            policy_loss,
            beta_loss,
            action,
            importance_weight,
            lambda_k
        ) = self._compute_loss(data)

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        self.beta_optim.zero_grad()
        beta_loss.backward()
        self.beta_optim.step()

        info = {'policy_loss': policy_loss.cpu().detach().item(),
                'beta_loss': beta_loss.cpu().detach().item(),
                'importance_weight': importance_weight.cpu().mean().item(),
                'lambda_k': lambda_k.cpu().mean().item(),
                'action': action}
        return info

    def _compute_weight(self, policy_logp, beta_logp):
        if self.offpolicy_correction:
            importance_weight = torch.exp(policy_logp - beta_logp).detach()
            wc = torch.tensor([self.weight_clip]).to(self.device)
            importance_weight = torch.min(importance_weight, wc)
        #    importance_weight = torch.clamp(
        #        importance_weight, self.weight_clip[0], self.weight_clip[1]
        #    )
        else:
            importance_weight = torch.tensor([1.]).float().to(self.device)
        return importance_weight

    def _compute_lambda_k(self, policy_logp):
        lam = (
            self.k * ((1. - policy_logp.exp()).pow(self.k - 1)).detach()
            if self.topk
            else torch.tensor([1.]).float().to(self.device)
        )
        return lam

    def _compute_loss(self, data):
        if self.adaptive_softmax:
            state, action = self.policy(data)
            policy_out = self.softmax_loss(action, data["action"])
            policy_logp = policy_out.output

            beta_action = self.beta(state.detach())
            beta_out = self.softmax_loss(beta_action, data["action"])
            beta_logp = beta_out.output
        else:
            state, all_logp, action = self.policy.get_log_probs(data)
            policy_logp = all_logp[:, data["action"]]

            b_logp, beta_logits = self.beta.get_log_probs(state.detach())
            beta_logp = (b_logp[:, data["action"]]).detach()

        importance_weight = self._compute_weight(policy_logp, beta_logp)
        lambda_k = self._compute_lambda_k(policy_logp)

        policy_loss = -(
                importance_weight * lambda_k * data["return"] * policy_logp
        ).mean()

        if self.adaptive_softmax:
            if "beta_label" in data:
                b_state = self.policy.get_beta_state(data)
                b_action = self.beta(b_state.detach())
                b_out = self.softmax_loss(b_action, data["beta_label"])
                beta_loss = b_out.loss
            else:
                beta_loss = beta_out.loss
        else:
            if "beta_label" in data:
                b_state = self.policy.get_beta_state(data)
                _, b_logits = self.beta.get_log_probs(b_state.detach())
                beta_loss = self.beta_criterion(b_logits, data["beta_label"])
            else:
                beta_loss = self.beta_criterion(beta_logits, data["action"])
        return policy_loss, beta_loss, action, importance_weight, lambda_k

    def compute_loss(self, data):
        (
            policy_loss,
            beta_loss,
            action,
            importance_weight,
            lambda_k
        ) = self._compute_loss(data)

        info = {'policy_loss': policy_loss.cpu().detach().item(),
                'beta_loss': beta_loss.cpu().detach().item(),
                'importance_weight': importance_weight.cpu().mean().item(),
                'lambda_k': lambda_k.cpu().mean().item(),
                'action': action}
        return info

    def get_log_probs(self, data=None, action=None):
        with torch.no_grad():
            if self.adaptive_softmax:
                if action is None:
                    _, action = self.policy.forward(data)
                log_probs = self.softmax_loss.log_prob(action)
            else:
            #    _, log_probs = self.policy.get_log_probs(data)
                if action is None:
                    _, action = self.policy.forward(data)
                log_probs = self.policy.softmax_fc(action)
        return log_probs

    def forward(self, state):
        policy_logits = self.policy.get_action(state)
        policy_dist = Categorical(logits=policy_logits)
        _, rec_idxs = torch.topk(policy_dist.probs, 10, dim=1)
        return rec_idxs
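
The two quantities logged during training, importance_weight and lambda_k, come from _compute_weight and _compute_lambda_k: the clipped importance ratio exp(log pi - log beta) that corrects for the off-policy behavior data, and the top-K correction K * (1 - pi)^(K - 1) that rescales the gradient when a slate of K items is recommended instead of a single one. A self-contained numeric sketch of the same arithmetic, using made-up log-probabilities and the K = 10, weight_clip = 2.0 settings from the run above:

import torch

k = 10             # slate size, as in the run above
weight_clip = 2.0

# Hypothetical per-example log-probabilities of the taken action under
# the policy pi and the behavior policy beta.
policy_logp = torch.tensor([-2.3, -0.7, -4.0])
beta_logp = torch.tensor([-2.0, -1.5, -3.5])

# Off-policy correction: w = min(pi / beta, clip), computed in log space
# (same effect as the torch.min call in _compute_weight).
importance_weight = torch.clamp(torch.exp(policy_logp - beta_logp),
                                max=weight_clip)

# Top-K correction: lambda_K = K * (1 - pi)^(K - 1).
lambda_k = k * (1. - policy_logp.exp()).pow(k - 1)

print(importance_weight)   # tensor([0.7408, 2.0000, 0.6065])
print(lambda_k)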

Trainer

import os
import sys
sys.path.append(os.pardir)
import warnings
warnings.filterwarnings("ignore")
import argparse
from pprint import pprint
import numpy as np
import torch
from torch.optim import Adam
from dbrl.data import process_data, build_dataloader
from dbrl.models import Reinforce
from dbrl.network import PolicyPi, Beta
from dbrl.trainer import train_model
from dbrl.utils import count_vars, init_param


def parse_args():
    parser = argparse.ArgumentParser(description="run_reinforce")
    parser.add_argument("--data", type=str, default="tianchi.csv")
    parser.add_argument("--user_embeds", type=str,
                        default="tianchi_user_embeddings.npy")
    parser.add_argument("--item_embeds", type=str,
                        default="tianchi_item_embeddings.npy")
    parser.add_argument("--n_epochs", type=int, default=100)
    parser.add_argument("--hist_num", type=int, default=10,
                        help="num of history items to consider")
    parser.add_argument("--n_rec", type=int, default=10,
                        help="num of items to recommend")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--hidden_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--weight_decay", type=float, default=0.)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--sess_mode", type=str, default="interval",
                        help="Specify when to end a session")
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    print("A list all args: \n======================")
    pprint(vars(args))
    print()

    # 1. Loading user and item embeddings

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    PATH = os.path.join("resources", args.data)
    with open(os.path.join("resources", args.user_embeds), "rb") as f:
        user_embeddings = np.load(f)
    with open(os.path.join("resources", args.item_embeds), "rb") as f:
        item_embeddings = np.load(f)
    item_embeddings[-1] = 0.   # last item is used for padding

    # 2. Setting model arguments/params

    n_epochs = args.n_epochs
    hist_num = args.hist_num
    batch_size = eval_batch_size = args.batch_size
    embed_size = item_embeddings.shape[1]
    hidden_size = args.hidden_size
    input_dim = embed_size * (hist_num + 1)
    action_dim = len(item_embeddings)
    policy_lr = args.lr
    beta_lr = args.lr
    weight_decay = args.weight_decay
    gamma = args.gamma
    n_rec = args.n_rec
    pad_val = len(item_embeddings) - 1
    sess_mode = args.sess_mode
    debug = True
    one_hour = int(60 * 60)
    reward_map = {"pv": 1., "cart": 2., "fav": 2., "buy": 3.}
    columns = ["user", "item", "label", "time", "sex", "age", "pur_power",
               "category", "shop", "brand"]

    cutoffs = [
        len(item_embeddings) // 20,
        len(item_embeddings) // 10,
        len(item_embeddings) // 3
    ]

    # 3. Building the data loader

    (
        n_users,
        n_items,
        train_user_consumed,
        test_user_consumed,
        train_sess_end,
        test_sess_end,
        train_rewards,
        test_rewards
    ) = process_data(PATH, columns, 0.2, time_col="time", sess_mode=sess_mode,
                     interval=one_hour, reward_shape=reward_map)

    train_loader, eval_loader = build_dataloader(
        n_users,
        n_items,
        hist_num,
        train_user_consumed,
        test_user_consumed,
        batch_size,
        sess_mode=sess_mode,
        train_sess_end=train_sess_end,
        test_sess_end=test_sess_end,
        n_workers=0,
        compute_return=True,
        neg_sample=False,
        train_rewards=train_rewards,
        test_rewards=test_rewards,
        reward_shape=reward_map
    )

    # 4. Building the model

    policy = PolicyPi(
        input_dim, action_dim, hidden_size, user_embeddings,
        item_embeddings, None, pad_val, 1, device
    ).to(device)
    beta = Beta(input_dim, action_dim, hidden_size).to(device)
    init_param(policy, beta)

    policy_optim = Adam(policy.parameters(), policy_lr, weight_decay=weight_decay)
    beta_optim = Adam(beta.parameters(), beta_lr, weight_decay=weight_decay)

    model = Reinforce(
        policy,
        policy_optim,
        beta,
        beta_optim,
        hidden_size,
        gamma,
        k=10,
        weight_clip=2.0,
        offpolicy_correction=True,
        topk=True,
        adaptive_softmax=False,
        cutoffs=cutoffs,
        device=device,
    )

    var_counts = tuple(count_vars(module) for module in [policy, beta])
    print(f'Number of parameters: policy: {var_counts[0]}, '
          f' beta: {var_counts[1]}')
    
    # 5. Training the model

    train_model(
        model,
        n_epochs,
        n_rec,
        n_users,
        train_user_consumed,
        test_user_consumed,
        hist_num,
        train_loader,
        eval_loader,
        item_embeddings,
        eval_batch_size,
        pad_val,
        device,
        debug=debug,
        eval_interval=10
    )

    torch.save(policy.state_dict(), "resources/model_reinforce.pt")
    print("train and save done!")