{
  "nbformat": 4,
  "cells": [
    {
      "source": [
        "%matplotlib inline"
      ],
      "metadata": {
        "collapsed": false
      },
      "execution_count": null,
      "cell_type": "code",
      "outputs": []
    },
    {
      "source": [
        "\n# Contextual bandit on MovieLens\n\n\nThe script uses real-world data to conduct contextual bandit experiments. Here we use the\nMovieLens 10M Dataset, which was released by GroupLens in January 2009. Please first pre-process\nthe datasets (use \"movielens_preprocess.py\"), and then you can run this example.\n\n"
      ],
      "metadata": {},
      "cell_type": "markdown"
    },
    {
      "source": [
        "import pandas as pd\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from striatum.storage import history\n",
        "from striatum.storage import model\n",
        "from striatum.bandit import ucb1\n",
        "from striatum.bandit import linucb\n",
        "from striatum.bandit import linthompsamp\n",
        "from striatum.bandit import exp4p\n",
        "from striatum.bandit import exp3\n",
        "from striatum.bandit.bandit import Action\n",
        "from sklearn.naive_bayes import MultinomialNB\n",
        "from sklearn.linear_model import LogisticRegression\n",
        "from sklearn.multiclass import OneVsRestClassifier\n",
        "\n",
        "\n",
        "def get_data():\n",
        "    \"\"\"Load the pre-processed MovieLens files from the working directory.\n",
        "\n",
        "    Returns:\n",
        "        A tuple ``(streaming_batch, user_feature, actions, reward_list,\n",
        "        action_context)``. ``actions`` is a list with one ``Action`` per\n",
        "        ``movie_id`` in ``actions.csv``; the other items are DataFrames read\n",
        "        from the tab-separated files of the same name.\n",
        "    \"\"\"\n",
        "    streaming_batch = pd.read_csv('streaming_batch.csv', sep='\\t', names=['user_id'], engine='c')\n",
        "    user_feature = pd.read_csv('user_feature.csv', sep='\\t', header=0, index_col=0, engine='c')\n",
        "    actions_id = list(pd.read_csv('actions.csv', sep='\\t', header=0, engine='c')['movie_id'])\n",
        "    reward_list = pd.read_csv('reward_list.csv', sep='\\t', header=0, engine='c')\n",
        "    action_context = pd.read_csv('action_context.csv', sep='\\t', header=0, engine='c')\n",
        "\n",
        "    actions = [Action(key) for key in actions_id]\n",
        "    return streaming_batch, user_feature, actions, reward_list, action_context\n",
        "\n",
        "\n",
        "def train_expert(action_context):\n",
        "    \"\"\"Fit the two experts used by Exp4.P on the action context table.\n",
        "\n",
        "    Columns from position 2 onward are used as features and the column at\n",
        "    position 1 as the label.\n",
        "\n",
        "    Returns:\n",
        "        ``[logreg, mnb]``: the fitted one-vs-rest logistic regression and\n",
        "        multinomial naive Bayes classifiers, in that order.\n",
        "    \"\"\"\n",
        "    features = action_context.iloc[:, 2:]\n",
        "    labels = action_context.iloc[:, 1]\n",
        "    logreg = OneVsRestClassifier(LogisticRegression())\n",
        "    mnb = OneVsRestClassifier(MultinomialNB())\n",
        "    logreg.fit(features, labels)\n",
        "    mnb.fit(features, labels)\n",
        "    return [logreg, mnb]\n",
        "\n",
        "\n",
        "def _expert_advice(feature, actions_id, experts):\n",
        "    \"\"\"Return ``{expert_index: {action_id: probability}}`` for one context.\n",
        "\n",
        "    The j-th entry of an expert's first ``predict_proba`` row is stored under\n",
        "    ``actions_id[j]``.\n",
        "    NOTE(review): this pairs probabilities with actions by position only --\n",
        "    confirm the classifiers' class order matches ``actions_id``.\n",
        "    \"\"\"\n",
        "    advice = {}\n",
        "    for i, expert in enumerate(experts):\n",
        "        prob = expert.predict_proba(feature)[0]\n",
        "        advice[i] = {actions_id[j]: prob[j] for j in range(len(prob))}\n",
        "    return advice\n",
        "\n",
        "\n",
        "def get_advice(context, actions_id, experts):\n",
        "    \"\"\"Build expert advice for every entry of ``context``.\n",
        "\n",
        "    Args:\n",
        "        context: mapping from a time key to the features passed to each\n",
        "            expert's ``predict_proba``.\n",
        "        actions_id: action ids, indexed in the same order as the predicted\n",
        "            probabilities.\n",
        "        experts: fitted classifiers exposing ``predict_proba``.\n",
        "\n",
        "    Returns:\n",
        "        ``{time: {expert_index: {action_id: probability}}}``.\n",
        "    \"\"\"\n",
        "    return {time: _expert_advice(context[time], actions_id, experts) for time in context.keys()}\n",
        "\n",
        "\n",
        "def policy_generation(bandit, actions):\n",
        "    \"\"\"Create the policy named by ``bandit`` with fresh in-memory storage.\n",
        "\n",
        "    Args:\n",
        "        bandit: one of 'Exp4P', 'LinUCB', 'LinThompSamp', 'UCB1', 'Exp3' or\n",
        "            'random'.\n",
        "        actions: list of ``Action`` objects the policy chooses from.\n",
        "\n",
        "    Returns:\n",
        "        The policy object, or ``0`` for 'random' (the random baseline needs\n",
        "        no policy object).\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if ``bandit`` is not one of the supported names.\n",
        "    \"\"\"\n",
        "    historystorage = history.MemoryHistoryStorage()\n",
        "    modelstorage = model.MemoryModelStorage()\n",
        "\n",
        "    if bandit == 'Exp4P':\n",
        "        policy = exp4p.Exp4P(actions, historystorage, modelstorage, delta=0.5, pmin=None)\n",
        "\n",
        "    elif bandit == 'LinUCB':\n",
        "        policy = linucb.LinUCB(actions, historystorage, modelstorage, 0.3, 20)\n",
        "\n",
        "    elif bandit == 'LinThompSamp':\n",
        "        policy = linthompsamp.LinThompSamp(actions, historystorage, modelstorage,\n",
        "                                           d=20, delta=0.61, r=0.01, epsilon=0.71)\n",
        "\n",
        "    elif bandit == 'UCB1':\n",
        "        policy = ucb1.UCB1(actions, historystorage, modelstorage)\n",
        "\n",
        "    elif bandit == 'Exp3':\n",
        "        policy = exp3.Exp3(actions, historystorage, modelstorage, gamma=0.2)\n",
        "\n",
        "    elif bandit == 'random':\n",
        "        policy = 0\n",
        "\n",
        "    else:\n",
        "        # Previously an unknown name fell through to an UnboundLocalError.\n",
        "        raise ValueError('Unknown bandit: {!r}'.format(bandit))\n",
        "\n",
        "    return policy\n",
        "\n",
        "\n",
        "def _was_watched(reward_list, user_id, action_id):\n",
        "    \"\"\"Return True if ``reward_list`` has a row pairing ``user_id`` with ``action_id``.\"\"\"\n",
        "    watched = reward_list.loc[reward_list['user_id'] == user_id, 'movie_id']\n",
        "    return bool((watched == action_id).any())\n",
        "\n",
        "\n",
        "def _record_outcome(seq_error, t, hit):\n",
        "    \"\"\"Store the cumulative error count at step ``t``: unchanged on a hit, +1 on a miss.\"\"\"\n",
        "    previous = seq_error[t - 1] if t > 0 else 0.0\n",
        "    seq_error[t] = previous if hit else previous + 1.0\n",
        "\n",
        "\n",
        "def policy_evaluation(policy, bandit, streaming_batch, user_feature, reward_list, actions, action_context=None):\n",
        "    \"\"\"Replay ``streaming_batch`` through a policy and count its mistakes.\n",
        "\n",
        "    At each step one action is chosen for the current user. The step is a hit\n",
        "    (reward 1.0) when ``reward_list`` contains that user/movie pair and a\n",
        "    miss (reward 0.0) otherwise.\n",
        "\n",
        "    Args:\n",
        "        policy: object returned by ``policy_generation`` (unused for 'random').\n",
        "        bandit: name of the algorithm; selects how the context is built.\n",
        "        streaming_batch: DataFrame whose first column holds the user id of\n",
        "            each step.\n",
        "        user_feature: DataFrame of user features indexed by user id.\n",
        "        reward_list: DataFrame with 'user_id' and 'movie_id' columns.\n",
        "        actions: list of ``Action`` objects.\n",
        "        action_context: training table for the experts; only read when\n",
        "            ``bandit == 'Exp4P'``.\n",
        "\n",
        "    Returns:\n",
        "        A ``(len(streaming_batch), 1)`` array holding the cumulative number\n",
        "        of misses after each step. It is all zeros if ``bandit`` matches no\n",
        "        branch.\n",
        "    \"\"\"\n",
        "    times = len(streaming_batch)\n",
        "    seq_error = np.zeros(shape=(times, 1))\n",
        "    actions_id = [action.action_id for action in actions]\n",
        "\n",
        "    if bandit in ['LinUCB', 'LinThompSamp', 'UCB1', 'Exp3']:\n",
        "        for t in range(times):\n",
        "            user_id = streaming_batch.iloc[t, 0]\n",
        "            feature = np.array(user_feature[user_feature.index == user_id])[0]\n",
        "            # Every action is offered with the same (user) feature vector.\n",
        "            full_context = {action_id: feature for action_id in actions_id}\n",
        "            history_id, action = policy.get_action(full_context, 1)\n",
        "            chosen_id = action[0]['action'].action_id\n",
        "\n",
        "            hit = _was_watched(reward_list, user_id, chosen_id)\n",
        "            policy.reward(history_id, {chosen_id: 1.0 if hit else 0.0})\n",
        "            _record_outcome(seq_error, t, hit)\n",
        "\n",
        "    elif bandit == 'Exp4P':\n",
        "        # The experts depend only on action_context, so fit them once instead\n",
        "        # of refitting on identical data at every step.\n",
        "        experts = train_expert(action_context) if times else []\n",
        "        for t in range(times):\n",
        "            user_id = streaming_batch.iloc[t, 0]\n",
        "            feature = user_feature[user_feature.index == user_id]\n",
        "            advice = _expert_advice(feature, actions_id, experts)\n",
        "            history_id, action = policy.get_action(advice)\n",
        "            chosen_id = action[0]['action'].action_id\n",
        "\n",
        "            hit = _was_watched(reward_list, user_id, chosen_id)\n",
        "            policy.reward(history_id, {chosen_id: 1.0 if hit else 0.0})\n",
        "            _record_outcome(seq_error, t, hit)\n",
        "\n",
        "    elif bandit == 'random':\n",
        "        for t in range(times):\n",
        "            user_id = streaming_batch.iloc[t, 0]\n",
        "            # randint's upper bound is exclusive, so pass the full length;\n",
        "            # len - 1 would never select the last action.\n",
        "            chosen_id = actions_id[np.random.randint(len(actions_id))]\n",
        "\n",
        "            hit = _was_watched(reward_list, user_id, chosen_id)\n",
        "            _record_outcome(seq_error, t, hit)\n",
        "\n",
        "    return seq_error\n",
        "\n",
        "\n",
        "def regret_calculation(seq_error):\n",
        "    \"\"\"Divide each cumulative error by its 1-based step number.\n",
        "\n",
        "    Returns:\n",
        "        A list with one entry per element of ``seq_error``.\n",
        "    \"\"\"\n",
        "    return [errors / step for step, errors in enumerate(seq_error, start=1)]\n",
        "\n",
        "\n",
        "def main(n_rounds=10000, seed=None):\n",
        "    \"\"\"Run every bandit on the first ``n_rounds`` users and plot the regret curves.\n",
        "\n",
        "    Args:\n",
        "        n_rounds: number of leading rows of the streaming batch to replay.\n",
        "        seed: if given, seeds NumPy's global random generator before the\n",
        "            experiments. NOTE(review): whether the striatum policies draw\n",
        "            from that generator is not visible here -- confirm before\n",
        "            relying on it for full reproducibility.\n",
        "    \"\"\"\n",
        "    if seed is not None:\n",
        "        np.random.seed(seed)\n",
        "\n",
        "    streaming_batch, user_feature, actions, reward_list, action_context = get_data()\n",
        "    streaming_batch_small = streaming_batch.iloc[0:n_rounds]\n",
        "\n",
        "    # conduct regret analyses\n",
        "    experiment_bandit = ['LinUCB', 'LinThompSamp', 'Exp4P', 'UCB1', 'Exp3', 'random']\n",
        "    col = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']\n",
        "    regret = {}\n",
        "    for bandit, colour in zip(experiment_bandit, col):\n",
        "        policy = policy_generation(bandit, actions)\n",
        "        seq_error = policy_evaluation(policy, bandit, streaming_batch_small, user_feature, reward_list,\n",
        "                                      actions, action_context)\n",
        "        regret[bandit] = regret_calculation(seq_error)\n",
        "        plt.plot(range(len(streaming_batch_small)), regret[bandit], c=colour, ls='-', label=bandit)\n",
        "\n",
        "    # Figure decoration only needs to happen once, after all curves are drawn.\n",
        "    plt.xlabel('time')\n",
        "    plt.ylabel('regret')\n",
        "    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)\n",
        "    axes = plt.gca()\n",
        "    axes.set_ylim([0, 1])\n",
        "    plt.title(\"Regret Bound with respect to T\")\n",
        "    plt.show()\n",
        "\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    main()"
      ],
      "metadata": {
        "collapsed": false
      },
      "execution_count": null,
      "cell_type": "code",
      "outputs": []
    }
  ],
  "metadata": {
    "kernelspec": {
      "language": "python",
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "version": 3,
        "name": "ipython"
      },
      "file_extension": ".py",
      "pygments_lexer": "ipython3",
      "nbconvert_exporter": "python",
      "mimetype": "text/x-python",
      "version": "3.4.3",
      "name": "python"
    }
  },
  "nbformat_minor": 0
}