Top-K Off-Policy Correction for a REINFORCE Recommender System¶
CLI run¶
!gdown --id 1erBjYEOa7IuOIGpI8pGPn1WNBAC4Rv0-
!git clone https://github.com/massquantity/DBRL.git
!unzip /content/ECommAI_EUIR_round2_train_20190821.zip
!mv ECommAI_EUIR_round2_train_20190816/*.csv DBRL/dbrl/resources
%cd DBRL/dbrl
Downloading...
From: https://drive.google.com/uc?id=1erBjYEOa7IuOIGpI8pGPn1WNBAC4Rv0-
To: /content/ECommAI_EUIR_round2_train_20190821.zip
100% 894M/894M [00:06<00:00, 146MB/s]
Cloning into 'DBRL'...
remote: Enumerating objects: 118, done.
remote: Counting objects: 100% (118/118), done.
remote: Compressing objects: 100% (83/83), done.
remote: Total 118 (delta 29), reused 114 (delta 25), pack-reused 0
Receiving objects: 100% (118/118), 203.89 KiB | 2.87 MiB/s, done.
Resolving deltas: 100% (29/29), done.
Archive: /content/ECommAI_EUIR_round2_train_20190821.zip
creating: ECommAI_EUIR_round2_train_20190816/
inflating: ECommAI_EUIR_round2_train_20190816/user_behavior.csv
inflating: ECommAI_EUIR_round2_train_20190816/item.csv
inflating: ECommAI_EUIR_round2_train_20190816/user.csv
/content/DBRL/dbrl
!python run_prepare_data.py
{'seed': 0}
n_users: 80000, n_items: 1047166, behavior length: 3234367
prepare data done!, time elapsed: 173.06
!python run_pretrain_embeddings.py --lr 0.001 --n_epochs 4
A list all args:
======================
{'batch_size': 2048,
'data': 'tianchi.csv',
'embed_size': 32,
'loss': 'cosine',
'lr': 0.001,
'n_epochs': 4,
'neg_item': 1,
'seed': 0}
n_users: 80000, n_items: 912114, train_shape: (2587518, 10), eval_shape: (646849, 10)
100% 2527/2527 [02:22<00:00, 17.73it/s]
100% 632/632 [00:11<00:00, 56.31it/s]
epoch 1, train_loss: 0.3253, eval loss: 0.3537, eval roc: 0.7370
100% 2527/2527 [02:21<00:00, 17.85it/s]
100% 632/632 [00:11<00:00, 55.87it/s]
epoch 2, train_loss: 0.2568, eval loss: 0.3351, eval roc: 0.7697
100% 2527/2527 [02:21<00:00, 17.84it/s]
100% 632/632 [00:11<00:00, 56.34it/s]
epoch 3, train_loss: 0.2260, eval loss: 0.3309, eval roc: 0.7772
100% 2527/2527 [02:20<00:00, 17.96it/s]
100% 632/632 [00:11<00:00, 57.03it/s]
epoch 4, train_loss: 0.2036, eval loss: 0.3296, eval roc: 0.7829
user_embeds shape: (80000, 32), item_embeds shape: (912115, 32)
pretrain embeddings done!
!python run_reinforce.py --n_epochs 1 --lr 1e-5
A list all args:
======================
{'batch_size': 128,
'data': 'tianchi.csv',
'gamma': 0.99,
'hidden_size': 64,
'hist_num': 10,
'item_embeds': 'tianchi_item_embeddings.npy',
'lr': 1e-05,
'n_epochs': 1,
'n_rec': 10,
'seed': 0,
'sess_mode': 'interval',
'user_embeds': 'tianchi_user_embeddings.npy',
'weight_decay': 0.0}
Number of parameters: policy: 118628454, beta: 59310067
Caution: Will compute loss every 10 step(s)
Epoch 1 start-time: 2021-10-21 10:46:33
train: 100% 19590/19590 [2:24:05<00:00, 2.27it/s]
last_eval: 100% 625/625 [00:47<00:00, 13.12it/s]
policy_loss: 665.2355, beta_loss: 13.6349, importance_weight: 0.8856, lambda_k: 9.9993,
reward: 455, ndcg_next_item: 0.000999, ndcg_all_item: 0.039790, ndcg: 0.027800
******************** EVAL ********************
eval: 100% 10516/10516 [20:26<00:00, 8.58it/s]
last_eval: 100% 625/625 [00:47<00:00, 13.12it/s]
policy_loss: 1333.3558, beta_loss: 13.5444, importance_weight: 0.9655, lambda_k: 9.9992,
reward: 290, ndcg_next_item: 0.001165, ndcg_all_item: 0.008780, ndcg: 0.007617
================================================================================
train and save done!
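The evaluation lines above report several NDCG variants (`ndcg_next_item`, `ndcg_all_item`, `ndcg`). As a rough illustration of what NDCG@k measures, here is a generic binary-relevance sketch, not the repo's exact implementation in `dbrl/evaluate/metrics.py`:

```python
import math

def ndcg_at_k(recommended, relevant, k=10):
    # DCG over the top-k recommendations with binary relevance:
    # a hit at rank r (0-based) contributes 1 / log2(r + 2).
    dcg = sum(
        1.0 / math.log2(rank + 2)
        for rank, item in enumerate(recommended[:k])
        if item in relevant
    )
    # Ideal DCG: all relevant items ranked first.
    ideal = sum(
        1.0 / math.log2(rank + 2)
        for rank in range(min(len(relevant), k))
    )
    return dcg / ideal if ideal > 0 else 0.0

# Relevant items {1, 2} appear at ranks 2 and 3 instead of 1 and 2,
# so the score is below 1; a perfect ordering scores exactly 1.
print(ndcg_at_k([3, 1, 2], {1, 2}, k=3))
print(ndcg_at_k([1, 2, 3], {1, 2}, k=3))  # 1.0
```

The low absolute values in the log (~0.001-0.04) are typical here: the model ranks 10 items out of a ~900K-item catalog.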
!apt-get -qq install tree
!tree --du -h .
Selecting previously unselected package tree.
(Reading database ... 155047 files and directories currently installed.)
Preparing to unpack .../tree_1.7.0-5_amd64.deb ...
Unpacking tree (1.7.0-5) ...
Setting up tree (1.7.0-5) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
.
├── [ 56K] data
│ ├── [7.7K] dataset.py
│ ├── [ 126] __init__.py
│ ├── [8.4K] process.py
│ ├── [ 24K] __pycache__
│ │ ├── [5.5K] dataset.cpython-37.pyc
│ │ ├── [ 284] __init__.cpython-37.pyc
│ │ ├── [5.6K] process.cpython-37.pyc
│ │ ├── [6.3K] session.cpython-37.pyc
│ │ └── [2.2K] split.cpython-37.pyc
│ ├── [9.6K] session.py
│ └── [2.5K] split.py
├── [ 17K] evaluate
│ ├── [4.1K] evaluate.py
│ ├── [ 45] __init__.py
│ ├── [1.6K] metrics.py
│ └── [7.3K] __pycache__
│ ├── [1.9K] evaluate.cpython-37.pyc
│ ├── [ 178] __init__.cpython-37.pyc
│ └── [1.2K] metrics.cpython-37.pyc
├── [ 0] __init__.py
├── [ 44K] models
│ ├── [7.6K] bcq.py
│ ├── [4.2K] ddpg.py
│ ├── [3.3K] dssm.py
│ ├── [ 103] __init__.py
│ ├── [ 19K] __pycache__
│ │ ├── [5.1K] bcq.cpython-37.pyc
│ │ ├── [3.3K] ddpg.cpython-37.pyc
│ │ ├── [2.4K] dssm.cpython-37.pyc
│ │ ├── [ 256] __init__.cpython-37.pyc
│ │ └── [3.9K] youtube_topk.cpython-37.pyc
│ └── [5.7K] youtube_topk.py
├── [ 24K] network
│ ├── [ 20] __init__.py
│ ├── [9.1K] net.py
│ └── [ 11K] __pycache__
│ ├── [ 134] __init__.cpython-37.pyc
│ └── [7.2K] net.cpython-37.pyc
├── [4.1K] __pycache__
│ └── [ 106] __init__.cpython-37.pyc
├── [3.0G] resources
│ ├── [ 0] aa
│ ├── [114M] item.csv
│ ├── [ 15M] item_map.json
│ ├── [574M] model_reinforce.pt
│ ├── [160M] tianchi.csv
│ ├── [111M] tianchi_item_embeddings.npy
│ ├── [9.8M] tianchi_user_embeddings.npy
│ ├── [2.0G] user_behavior.csv
│ ├── [ 19M] user.csv
│ └── [1.2M] user_map.json
├── [5.9K] run_bcq.py
├── [5.1K] run_ddpg.py
├── [2.5K] run_prepare_data.py
├── [4.4K] run_pretrain_embeddings.py
├── [5.1K] run_reinforce.py
├── [9.9K] serialization
│ ├── [ 43] __init__.py
│ ├── [5.0K] __pycache__
│ │ ├── [ 182] __init__.cpython-37.pyc
│ │ └── [ 845] serialize.cpython-37.pyc
│ └── [ 889] serialize.py
├── [ 18K] trainer
│ ├── [ 68] __init__.py
│ ├── [1.9K] pretrain.py
│ ├── [8.2K] __pycache__
│ │ ├── [ 202] __init__.cpython-37.pyc
│ │ ├── [1.5K] pretrain.cpython-37.pyc
│ │ └── [2.5K] train.cpython-37.pyc
│ └── [3.9K] train.py
└── [ 21K] utils
├── [1.7K] info.py
├── [ 156] __init__.py
├── [2.2K] misc.py
├── [1.7K] params.py
├── [ 10K] __pycache__
│ ├── [1.5K] info.cpython-37.pyc
│ ├── [ 325] __init__.cpython-37.pyc
│ ├── [2.0K] misc.cpython-37.pyc
│ ├── [1.5K] params.cpython-37.pyc
│ └── [ 920] sampling.cpython-37.pyc
└── [ 908] sampling.py
3.0G used in 16 directories, 67 files
Code analysis¶
Data preparation¶
import os
import sys
sys.path.append(os.pardir)
import warnings
warnings.filterwarnings("ignore")
import argparse
import time
import numpy as np
import pandas as pd
def parse_args():
    parser = argparse.ArgumentParser(description="run_prepare_data")
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args(args={})


def bucket_age(age):
    if age < 30:
        return 1
    elif age < 40:
        return 2
    elif age < 50:
        return 3
    else:
        return 4


if __name__ == "__main__":
    args = parse_args()
    print(vars(args))
    np.random.seed(args.seed)
    start_time = time.perf_counter()

    # 1. Loading the data into memory
    user_feat = pd.read_csv("resources/user.csv", header=None,
                            names=["user", "sex", "age", "pur_power"])
    item_feat = pd.read_csv("resources/item.csv", header=None,
                            names=["item", "category", "shop", "brand"])
    behavior = pd.read_csv("resources/user_behavior.csv", header=None,
                           names=["user", "item", "behavior", "time"])

    # 2. Sorting values chronologically and dropping duplicate records
    behavior = behavior.sort_values(by="time").reset_index(drop=True)
    behavior = behavior.drop_duplicates(subset=["user", "item", "behavior"])

    # 3. Choosing 60K random users with short journeys and 20K with long journeys
    user_counts = behavior.groupby("user")[["user"]].count().rename(
        columns={"user": "count_user"}
    ).sort_values("count_user", ascending=False)
    short_users = np.array(
        user_counts[
            (user_counts.count_user > 5) & (user_counts.count_user <= 50)
        ].index
    )
    long_users = np.array(
        user_counts[
            (user_counts.count_user > 50) & (user_counts.count_user <= 200)
        ].index
    )
    short_chosen_users = np.random.choice(short_users, 60000, replace=False)
    long_chosen_users = np.random.choice(long_users, 20000, replace=False)
    chosen_users = np.concatenate([short_chosen_users, long_chosen_users])
    behavior = behavior[behavior.user.isin(chosen_users)]
    print(f"n_users: {behavior.user.nunique()}, "
          f"n_items: {behavior.item.nunique()}, "
          f"behavior length: {len(behavior)}")

    # 4. Merging with all features, bucketizing the age and saving the processed data
    behavior = behavior.merge(user_feat, on="user")
    behavior = behavior.merge(item_feat, on="item")
    behavior["age"] = behavior["age"].apply(bucket_age)
    behavior = behavior.sort_values(by="time").reset_index(drop=True)
    behavior.to_csv("resources/tianchi.csv", header=None, index=False)
    print(f"prepare data done!, "
          f"time elapsed: {(time.perf_counter() - start_time):.2f}")
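The user-selection step in section 3 above is easy to verify on a toy log. The user ids and event counts below are made up for illustration; the bucketing thresholds (more than 5 and at most 50 events for "short journey", 51-200 for "long journey") are the same as in the script:

```python
import numpy as np
import pandas as pd

# Toy behavior log: users 1, 2 and 3 have 3, 10 and 60 events respectively.
rng = np.random.default_rng(0)
behavior = pd.DataFrame({
    "user": np.repeat([1, 2, 3], [3, 10, 60]),
    "item": rng.integers(0, 100, size=73),
})

# Same bucketing logic as above: count events per user, then split
# into short-journey (6-50 events) and long-journey (51-200) users.
user_counts = behavior.groupby("user")[["user"]].count().rename(
    columns={"user": "count_user"}
)
short_users = user_counts[
    (user_counts.count_user > 5) & (user_counts.count_user <= 50)
].index.to_numpy()
long_users = user_counts[
    (user_counts.count_user > 50) & (user_counts.count_user <= 200)
].index.to_numpy()

print(short_users)  # user 2 only (10 events); user 1 is below the cutoff
print(long_users)   # user 3 only (60 events)
```

Users below the 6-event floor are dropped entirely, which is why the run above ends up with exactly 80000 users.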
Embeddings¶
import os
import sys
sys.path.append(os.pardir)
import warnings
warnings.filterwarnings("ignore")
import argparse
from pprint import pprint
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from dbrl.data import process_feat_data, FeatDataset
from dbrl.models import DSSM
from dbrl.utils import sample_items_random, init_param_dssm, generate_embeddings
from dbrl.trainer import pretrain_model
from dbrl.serialization import save_npy, save_json
def parse_args():
    parser = argparse.ArgumentParser(description="run_pretrain_embeddings")
    parser.add_argument("--data", type=str, default="tianchi.csv")
    parser.add_argument("--n_epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=2048)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--embed_size", type=int, default=32)
    parser.add_argument("--loss", type=str, default="cosine",
                        help="cosine or bce loss")
    parser.add_argument("--neg_item", type=int, default=1)
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    print("A list all args: \n======================")
    pprint(vars(args))
    print()

    # 1. Setting arguments/params
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    PATH = os.path.join("resources", args.data)
    EMBEDDING_PATH = "resources/"
    static_feat = ["sex", "age", "pur_power"]
    dynamic_feat = ["category", "shop", "brand"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_epochs = args.n_epochs
    batch_size = args.batch_size
    lr = args.lr
    item_embed_size = args.embed_size
    feat_embed_size = args.embed_size
    hidden_size = (256, 128)
    criterion = (
        nn.CosineEmbeddingLoss()
        if args.loss == "cosine"
        else nn.BCEWithLogitsLoss()
    )
    criterion_type = (
        "cosine"
        if "cosine" in criterion.__class__.__name__.lower()
        else "bce"
    )
    neg_label = -1. if criterion_type == "cosine" else 0.
    neg_item = args.neg_item

    # 2. Preprocessing
    columns = ["user", "item", "label", "time", "sex", "age", "pur_power",
               "category", "shop", "brand"]
    (
        n_users,
        n_items,
        train_user_consumed,
        eval_user_consumed,
        train_data,
        eval_data,
        user_map,
        item_map,
        feat_map
    ) = process_feat_data(
        PATH, columns, test_size=0.2, time_col="time",
        static_feat=static_feat, dynamic_feat=dynamic_feat
    )
    print(f"n_users: {n_users}, n_items: {n_items}, "
          f"train_shape: {train_data.shape}, eval_shape: {eval_data.shape}")

    # 3. Random negative sampling
    train_user, train_item, train_label = sample_items_random(
        train_data, n_items, train_user_consumed, neg_label, neg_item
    )
    eval_user, eval_item, eval_label = sample_items_random(
        eval_data, n_items, eval_user_consumed, neg_label, neg_item
    )

    # 4. Putting data into torch dataset format and dataloader
    train_dataset = FeatDataset(
        train_user,
        train_item,
        train_label,
        feat_map,
        static_feat,
        dynamic_feat
    )
    eval_dataset = FeatDataset(
        eval_user,
        eval_item,
        eval_label,
        feat_map,
        static_feat,
        dynamic_feat
    )
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=0)
    eval_loader = DataLoader(dataset=eval_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=0)

    # 5. DSSM embedding model training
    model = DSSM(
        item_embed_size,
        feat_embed_size,
        n_users,
        n_items,
        hidden_size,
        feat_map,
        static_feat,
        dynamic_feat,
        use_bn=True
    ).to(device)
    init_param_dssm(model)
    optimizer = Adam(model.parameters(), lr=lr)  # weight_decay could be added here
    pretrain_model(model, train_loader, eval_loader, n_epochs, criterion,
                   criterion_type, optimizer, device)

    # 6. Generate and save embeddings
    user_embeddings, item_embeddings = generate_embeddings(
        model, n_users, n_items, feat_map, static_feat, dynamic_feat, device
    )
    print(f"user_embeds shape: {user_embeddings.shape},"
          f" item_embeds shape: {item_embeddings.shape}")
    save_npy(user_embeddings, item_embeddings, EMBEDDING_PATH)
    save_json(
        user_map, item_map, user_embeddings, item_embeddings, EMBEDDING_PATH
    )
    print("pretrain embeddings done!")
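One detail worth noting in the script above: `nn.CosineEmbeddingLoss` expects targets of +1 for a similar pair and -1 for a dissimilar pair, which is why the script sets `neg_label = -1.` under the cosine loss (versus 0. for BCE). A small self-contained check of that convention, with made-up 2-d vectors:

```python
import torch
import torch.nn as nn

criterion = nn.CosineEmbeddingLoss()  # default margin=0, mean reduction

# Two identical user/item vector pairs: one labeled positive, one negative.
user_vec = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
item_vec = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
labels = torch.tensor([1.0, -1.0])

loss = criterion(user_vec, item_vec, labels)
# Positive pair: 1 - cos = 0. Negative pair: max(0, cos - margin) = 1.
# Mean over the batch = 0.5.
print(loss.item())  # 0.5
```

So the loss pushes positive user-item pairs toward cosine similarity 1 and negative (randomly sampled) pairs toward low similarity, which is what the negative-sampling step in section 3 supplies.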
REINFORCE model¶
import torch
import torch.nn as nn
from torch.distributions import Categorical
class Reinforce(nn.Module):
    def __init__(
            self,
            policy,
            policy_optim,
            beta,
            beta_optim,
            hidden_size,
            gamma=0.99,
            k=10,
            weight_clip=2.0,
            offpolicy_correction=True,
            topk=True,
            adaptive_softmax=True,
            cutoffs=None,
            device=torch.device("cpu"),
    ):
        super(Reinforce, self).__init__()
        self.policy = policy
        self.policy_optim = policy_optim
        self.beta = beta
        self.beta_optim = beta_optim
        self.beta_criterion = nn.CrossEntropyLoss()
        self.gamma = gamma
        self.k = k
        self.weight_clip = weight_clip
        self.offpolicy_correction = offpolicy_correction
        self.topk = topk
        self.adaptive_softmax = adaptive_softmax
        if adaptive_softmax:
            assert cutoffs is not None, (
                "must provide cutoffs when using adaptive_softmax"
            )
            self.softmax_loss = nn.AdaptiveLogSoftmaxWithLoss(
                in_features=hidden_size,
                n_classes=policy.item_embeds.weight.size(0),
                cutoffs=cutoffs,
                div_value=4.
            ).to(device)
        self.device = device

    def update(self, data):
        (
            policy_loss,
            beta_loss,
            action,
            importance_weight,
            lambda_k
        ) = self._compute_loss(data)

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        self.beta_optim.zero_grad()
        beta_loss.backward()
        self.beta_optim.step()

        info = {'policy_loss': policy_loss.cpu().detach().item(),
                'beta_loss': beta_loss.cpu().detach().item(),
                'importance_weight': importance_weight.cpu().mean().item(),
                'lambda_k': lambda_k.cpu().mean().item(),
                'action': action}
        return info

    def _compute_weight(self, policy_logp, beta_logp):
        if self.offpolicy_correction:
            importance_weight = torch.exp(policy_logp - beta_logp).detach()
            wc = torch.tensor([self.weight_clip]).to(self.device)
            importance_weight = torch.min(importance_weight, wc)
            # importance_weight = torch.clamp(
            #     importance_weight, self.weight_clip[0], self.weight_clip[1]
            # )
        else:
            importance_weight = torch.tensor([1.]).float().to(self.device)
        return importance_weight

    def _compute_lambda_k(self, policy_logp):
        lam = (
            self.k * ((1. - policy_logp.exp()).pow(self.k - 1)).detach()
            if self.topk
            else torch.tensor([1.]).float().to(self.device)
        )
        return lam

    def _compute_loss(self, data):
        if self.adaptive_softmax:
            state, action = self.policy(data)
            policy_out = self.softmax_loss(action, data["action"])
            policy_logp = policy_out.output
            beta_action = self.beta(state.detach())
            beta_out = self.softmax_loss(beta_action, data["action"])
            beta_logp = beta_out.output
        else:
            state, all_logp, action = self.policy.get_log_probs(data)
            policy_logp = all_logp[:, data["action"]]
            b_logp, beta_logits = self.beta.get_log_probs(state.detach())
            beta_logp = (b_logp[:, data["action"]]).detach()

        importance_weight = self._compute_weight(policy_logp, beta_logp)
        lambda_k = self._compute_lambda_k(policy_logp)
        policy_loss = -(
            importance_weight * lambda_k * data["return"] * policy_logp
        ).mean()

        if self.adaptive_softmax:
            if "beta_label" in data:
                b_state = self.policy.get_beta_state(data)
                b_action = self.beta(b_state.detach())
                b_out = self.softmax_loss(b_action, data["beta_label"])
                beta_loss = b_out.loss
            else:
                beta_loss = beta_out.loss
        else:
            if "beta_label" in data:
                b_state = self.policy.get_beta_state(data)
                _, b_logits = self.beta.get_log_probs(b_state.detach())
                beta_loss = self.beta_criterion(b_logits, data["beta_label"])
            else:
                beta_loss = self.beta_criterion(beta_logits, data["action"])
        return policy_loss, beta_loss, action, importance_weight, lambda_k

    def compute_loss(self, data):
        (
            policy_loss,
            beta_loss,
            action,
            importance_weight,
            lambda_k
        ) = self._compute_loss(data)
        info = {'policy_loss': policy_loss.cpu().detach().item(),
                'beta_loss': beta_loss.cpu().detach().item(),
                'importance_weight': importance_weight.cpu().mean().item(),
                'lambda_k': lambda_k.cpu().mean().item(),
                'action': action}
        return info

    def get_log_probs(self, data=None, action=None):
        with torch.no_grad():
            if self.adaptive_softmax:
                if action is None:
                    _, action = self.policy.forward(data)
                log_probs = self.softmax_loss.log_prob(action)
            else:
                # _, log_probs = self.policy.get_log_probs(data)
                if action is None:
                    _, action = self.policy.forward(data)
                log_probs = self.policy.softmax_fc(action)
            return log_probs

    def forward(self, state):
        policy_logits = self.policy.get_action(state)
        policy_dist = Categorical(logits=policy_logits)
        _, rec_idxs = torch.topk(policy_dist.probs, 10, dim=1)
        return rec_idxs
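The two corrections at the heart of `_compute_weight` and `_compute_lambda_k` are easy to sanity-check numerically. A plain-Python sketch of the same formulas, min(π(a|s)/β(a|s), c) and λ_K = K·(1-π(a|s))^(K-1), with made-up probabilities:

```python
import math

def importance_weight(policy_logp, beta_logp, weight_clip=2.0):
    # Clipped off-policy correction: min(pi(a|s) / beta(a|s), c)
    return min(math.exp(policy_logp - beta_logp), weight_clip)

def lambda_k(policy_logp, k=10):
    # Top-K correction multiplier: K * (1 - pi(a|s))^(K-1)
    p = math.exp(policy_logp)
    return k * (1.0 - p) ** (k - 1)

# When pi >> beta the raw ratio (here 9) is clipped at weight_clip = 2.0,
# which matches the importance_weight ~0.89-0.97 logged above staying bounded.
print(importance_weight(math.log(0.9), math.log(0.1)))  # 2.0

# lambda_K nearly zeroes the gradient for already-likely actions...
print(round(lambda_k(math.log(0.5)), 4))   # 10 * 0.5**9 -> 0.0195
# ...but is close to K for rare actions, matching the lambda_k ~9.999 logged
# above (action probabilities are tiny over a ~900K-item catalog).
print(round(lambda_k(math.log(0.01)), 4))  # 9.1352
```

This is why the training log shows `lambda_k: 9.9993`: with roughly a million candidate items, individual action probabilities are minuscule, so (1-π)^(K-1) ≈ 1 and λ_K ≈ K = 10.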
Trainer¶
import os
import sys
sys.path.append(os.pardir)
import warnings
warnings.filterwarnings("ignore")
import argparse
from pprint import pprint
import numpy as np
import torch
from torch.optim import Adam
from dbrl.data import process_data, build_dataloader
from dbrl.models import Reinforce
from dbrl.network import PolicyPi, Beta
from dbrl.trainer import train_model
from dbrl.utils import count_vars, init_param
def parse_args():
    parser = argparse.ArgumentParser(description="run_reinforce")
    parser.add_argument("--data", type=str, default="tianchi.csv")
    parser.add_argument("--user_embeds", type=str,
                        default="tianchi_user_embeddings.npy")
    parser.add_argument("--item_embeds", type=str,
                        default="tianchi_item_embeddings.npy")
    parser.add_argument("--n_epochs", type=int, default=100)
    parser.add_argument("--hist_num", type=int, default=10,
                        help="num of history items to consider")
    parser.add_argument("--n_rec", type=int, default=10,
                        help="num of items to recommend")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--hidden_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--weight_decay", type=float, default=0.)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--sess_mode", type=str, default="interval",
                        help="Specify when to end a session")
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    print("A list all args: \n======================")
    pprint(vars(args))
    print()

    # 1. Loading user and item embeddings
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    PATH = os.path.join("resources", args.data)
    with open(os.path.join("resources", args.user_embeds), "rb") as f:
        user_embeddings = np.load(f)
    with open(os.path.join("resources", args.item_embeds), "rb") as f:
        item_embeddings = np.load(f)
    item_embeddings[-1] = 0.  # last item is used for padding

    # 2. Setting model arguments/params
    n_epochs = args.n_epochs
    hist_num = args.hist_num
    batch_size = eval_batch_size = args.batch_size
    embed_size = item_embeddings.shape[1]
    hidden_size = args.hidden_size
    input_dim = embed_size * (hist_num + 1)
    action_dim = len(item_embeddings)
    policy_lr = args.lr
    beta_lr = args.lr
    weight_decay = args.weight_decay
    gamma = args.gamma
    n_rec = args.n_rec
    pad_val = len(item_embeddings) - 1
    sess_mode = args.sess_mode
    debug = True
    one_hour = int(60 * 60)
    reward_map = {"pv": 1., "cart": 2., "fav": 2., "buy": 3.}
    columns = ["user", "item", "label", "time", "sex", "age", "pur_power",
               "category", "shop", "brand"]
    cutoffs = [
        len(item_embeddings) // 20,
        len(item_embeddings) // 10,
        len(item_embeddings) // 3
    ]

    # 3. Building the data loader
    (
        n_users,
        n_items,
        train_user_consumed,
        test_user_consumed,
        train_sess_end,
        test_sess_end,
        train_rewards,
        test_rewards
    ) = process_data(PATH, columns, 0.2, time_col="time", sess_mode=sess_mode,
                     interval=one_hour, reward_shape=reward_map)
    train_loader, eval_loader = build_dataloader(
        n_users,
        n_items,
        hist_num,
        train_user_consumed,
        test_user_consumed,
        batch_size,
        sess_mode=sess_mode,
        train_sess_end=train_sess_end,
        test_sess_end=test_sess_end,
        n_workers=0,
        compute_return=True,
        neg_sample=False,
        train_rewards=train_rewards,
        test_rewards=test_rewards,
        reward_shape=reward_map
    )

    # 4. Building the model
    policy = PolicyPi(
        input_dim, action_dim, hidden_size, user_embeddings,
        item_embeddings, None, pad_val, 1, device
    ).to(device)
    beta = Beta(input_dim, action_dim, hidden_size).to(device)
    init_param(policy, beta)
    policy_optim = Adam(policy.parameters(), policy_lr, weight_decay=weight_decay)
    beta_optim = Adam(beta.parameters(), beta_lr, weight_decay=weight_decay)
    model = Reinforce(
        policy,
        policy_optim,
        beta,
        beta_optim,
        hidden_size,
        gamma,
        k=10,
        weight_clip=2.0,
        offpolicy_correction=True,
        topk=True,
        adaptive_softmax=False,
        cutoffs=cutoffs,
        device=device,
    )
    var_counts = tuple(count_vars(module) for module in [policy, beta])
    print(f'Number of parameters: policy: {var_counts[0]}, '
          f' beta: {var_counts[1]}')

    # 5. Training the model
    train_model(
        model,
        n_epochs,
        n_rec,
        n_users,
        train_user_consumed,
        test_user_consumed,
        hist_num,
        train_loader,
        eval_loader,
        item_embeddings,
        eval_batch_size,
        pad_val,
        device,
        debug=debug,
        eval_interval=10
    )

    torch.save(policy.state_dict(), "resources/model_reinforce.pt")
    print("train and save done!")
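The script saves only the policy's `state_dict` (the 574M `model_reinforce.pt` in the tree listing). Serving would reload it and take a top-k over the action logits, as `Reinforce.forward` does above. A minimal sketch of that round trip with a stand-in `nn.Linear` policy and a hypothetical path; the real `PolicyPi` lives in `dbrl.network`:

```python
import torch
import torch.nn as nn

# Stand-in for the trained policy: 4-dim state -> 3 action logits.
policy = nn.Linear(4, 3)
torch.save(policy.state_dict(), "/tmp/model_reinforce.pt")

# At serving time: rebuild the same architecture, then load the weights.
restored = nn.Linear(4, 3)
restored.load_state_dict(torch.load("/tmp/model_reinforce.pt"))
restored.eval()  # inference mode (matters for dropout/batch-norm layers)

state = torch.randn(1, 4)
with torch.no_grad():
    # Top-k over the logits mirrors Reinforce.forward above (k=3 here).
    _, rec_idxs = torch.topk(restored(state), k=3, dim=1)
print(rec_idxs.shape)  # torch.Size([1, 3])
```

`load_state_dict` requires an identically shaped module, so the serving code must construct `PolicyPi` with the same `input_dim`, `action_dim` and embeddings used at training time.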