diff --git a/cmsc320_hw4.ipynb b/cmsc320_hw4.ipynb
new file mode 100644
index 00000000..8ab9fbfc
--- /dev/null
+++ b/cmsc320_hw4.ipynb
@@ -0,0 +1,770 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "authorship_tag": "ABX9TyNlOgZTGUUqhQby0su1ySpw",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from google.colab import drive\n",
+ "drive.mount('/content/drive')"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "4n8qy4J_Tb7B",
+ "outputId": "c16ec0f2-dbb4-4d6d-f491-e25edf3dcff8"
+ },
+ "execution_count": 1,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "id": "1fERlpkyTSF4"
+ },
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import pprint\n",
+ "import os\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "from sklearn.neural_network import MLPClassifier\n",
+ "from sklearn.svm import LinearSVC\n",
+ "from sklearn.metrics import confusion_matrix\n",
+ "from sklearn.model_selection import KFold\n",
+ "from sklearn.neural_network import MLPClassifier\n",
+ "from sklearn.tree import DecisionTreeClassifier\n",
+ "from sklearn.decomposition import PCA"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Loading data"
+ ],
+ "metadata": {
+ "id": "Pd-palSEoc54"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Load the dataset; the models below are evaluated with 10-fold cross-validation.\n",
+ "\n",
+ "path = \"./drive/MyDrive/cmsc320/HW4/homework4.csv\"\n",
+ "df = pd.read_csv(path)\n",
+ "\n",
+ "# Columns \"0\"..\"11\" are the features; \"Results\" is the label and ends up as the last column.\n",
+ "index = [\"%d\"%i for i in range(12)] + [\"Results\"]\n",
+ "# Force a float array so the standardized values are not truncated when written back in place.\n",
+ "data = np.array(df[index], dtype=float)\n",
+ "\n",
+ "# Standardize every feature to zero mean / unit variance: z = (x - mean) / std.\n",
+ "# (Dividing by std first and subtracting the raw mean afterwards does not center the columns.)\n",
+ "means = np.mean(data[:,:12], axis=0)\n",
+ "stds = np.std(data[:,:12], axis=0)\n",
+ "stds[stds == 0] = 1.0  # constant column: avoid dividing by zero\n",
+ "data[:,:12] = (data[:,:12] - means) / stds\n",
+ "\n",
+ "# Shuffle the rows once (fixed seed so Restart & Run All reproduces the same folds);\n",
+ "# the KFold splitters below do not shuffle on their own.\n",
+ "np.random.seed(320)\n",
+ "np.random.shuffle(data)"
+ ],
+ "metadata": {
+ "id": "vyu85xCKTtWt"
+ },
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Model 1: a linear SVM evaluated with 10-fold cross-validation. Across the ten folds it reaches a peak accuracy of 75.1% and a peak precision of 32.9%."
+ ],
+ "metadata": {
+ "id": "xn_67dMSoKkM"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# 10-fold cross-validation of a class-weighted linear SVM\n",
+ "kfold = KFold(n_splits=10)\n",
+ "\n",
+ "for i, (train_index, test_index) in enumerate(kfold.split(data)):\n",
+ "    # Columns 0..11 are the features, column 12 is the label\n",
+ "    train = data[train_index, :]\n",
+ "    test = data[test_index, :]\n",
+ "\n",
+ "    svm_classifier = LinearSVC(class_weight='balanced')\n",
+ "    svm_classifier.fit(train[:, :12], train[:,12])\n",
+ "\n",
+ "    output = svm_classifier.predict(test[:,:12])\n",
+ "    #C_ij = i actual, j predicted\n",
+ "    c_matrix = confusion_matrix(test[:,12], output)\n",
+ "    print(\"confusion matrix:\")\n",
+ "    print(c_matrix)\n",
+ "    print(\"true positive:\\t%d\\ntrue negative:\\t%d\\nfalse positive:\\t%d\\nfalse negative:\\t%d\"%(c_matrix[1,1], c_matrix[0,0], c_matrix[0,1], c_matrix[1,0]))\n",
+ "    precision = c_matrix[1,1] / (c_matrix[1,1] + c_matrix[0,1])\n",
+ "    recall = c_matrix[1,1] / (c_matrix[1,1] + c_matrix[1,0])\n",
+ "    # Accuracy over the real size of this fold instead of a hard-coded 1000\n",
+ "    accuracy = np.mean(np.equal(output, test[:,12]))\n",
+ "    print(\"\\naccuracy:\\t%f\"%accuracy)\n",
+ "    print(\"precision:\\t%f\\nrecall:\\t\\t%f\"%(precision, recall))\n",
+ "    print(\"\\n\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Ov6GhBNMlQ-E",
+ "outputId": "ccc87c4e-d257-4625-9f91-b697776a5a6b"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "confusion matrix:\n",
+ "[[612 239]\n",
+ " [ 52 97]]\n",
+ "true positive:\t97\n",
+ "true negative:\t612\n",
+ "false positive:\t239\n",
+ "false negative:\t52\n",
+ "\n",
+ "accuracy:\t0.709000\n",
+ "precision:\t0.288690\n",
+ "recall:\t\t0.651007\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[637 234]\n",
+ " [ 45 84]]\n",
+ "true positive:\t84\n",
+ "true negative:\t637\n",
+ "false positive:\t234\n",
+ "false negative:\t45\n",
+ "\n",
+ "accuracy:\t0.721000\n",
+ "precision:\t0.264151\n",
+ "recall:\t\t0.651163\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[653 193]\n",
+ " [ 63 91]]\n",
+ "true positive:\t91\n",
+ "true negative:\t653\n",
+ "false positive:\t193\n",
+ "false negative:\t63\n",
+ "\n",
+ "accuracy:\t0.744000\n",
+ "precision:\t0.320423\n",
+ "recall:\t\t0.590909\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[650 196]\n",
+ " [ 58 96]]\n",
+ "true positive:\t96\n",
+ "true negative:\t650\n",
+ "false positive:\t196\n",
+ "false negative:\t58\n",
+ "\n",
+ "accuracy:\t0.746000\n",
+ "precision:\t0.328767\n",
+ "recall:\t\t0.623377\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[644 221]\n",
+ " [ 55 80]]\n",
+ "true positive:\t80\n",
+ "true negative:\t644\n",
+ "false positive:\t221\n",
+ "false negative:\t55\n",
+ "\n",
+ "accuracy:\t0.724000\n",
+ "precision:\t0.265781\n",
+ "recall:\t\t0.592593\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[633 223]\n",
+ " [ 45 99]]\n",
+ "true positive:\t99\n",
+ "true negative:\t633\n",
+ "false positive:\t223\n",
+ "false negative:\t45\n",
+ "\n",
+ "accuracy:\t0.732000\n",
+ "precision:\t0.307453\n",
+ "recall:\t\t0.687500\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[648 220]\n",
+ " [ 47 85]]\n",
+ "true positive:\t85\n",
+ "true negative:\t648\n",
+ "false positive:\t220\n",
+ "false negative:\t47\n",
+ "\n",
+ "accuracy:\t0.733000\n",
+ "precision:\t0.278689\n",
+ "recall:\t\t0.643939\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[666 196]\n",
+ " [ 53 85]]\n",
+ "true positive:\t85\n",
+ "true negative:\t666\n",
+ "false positive:\t196\n",
+ "false negative:\t53\n",
+ "\n",
+ "accuracy:\t0.751000\n",
+ "precision:\t0.302491\n",
+ "recall:\t\t0.615942\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[662 208]\n",
+ " [ 44 86]]\n",
+ "true positive:\t86\n",
+ "true negative:\t662\n",
+ "false positive:\t208\n",
+ "false negative:\t44\n",
+ "\n",
+ "accuracy:\t0.748000\n",
+ "precision:\t0.292517\n",
+ "recall:\t\t0.661538\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[634 222]\n",
+ " [ 57 87]]\n",
+ "true positive:\t87\n",
+ "true negative:\t634\n",
+ "false positive:\t222\n",
+ "false negative:\t57\n",
+ "\n",
+ "accuracy:\t0.721000\n",
+ "precision:\t0.281553\n",
+ "recall:\t\t0.604167\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Model 2: a neural network (MLP). It achieves about 90% accuracy or better on almost every fold of the 10-fold CV, and its precision stays above 79% on every fold; recall, however, never exceeds 54%."
+ ],
+ "metadata": {
+ "id": "air1sHHBn3zr"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# 10-fold cross-validation of a default MLP classifier\n",
+ "kfold = KFold(n_splits=10)\n",
+ "\n",
+ "for i, (train_index, test_index) in enumerate(kfold.split(data)):\n",
+ "    # Columns 0..11 are the features, column 12 is the label\n",
+ "    train = data[train_index, :]\n",
+ "    test = data[test_index, :]\n",
+ "\n",
+ "    nn_classifier = MLPClassifier()\n",
+ "    nn_classifier.fit(train[:, :12], train[:,12])\n",
+ "\n",
+ "    output = nn_classifier.predict(test[:,:12])\n",
+ "    #C_ij = i actual, j predicted\n",
+ "    c_matrix = confusion_matrix(test[:,12], output)\n",
+ "    print(\"confusion matrix:\")\n",
+ "    print(c_matrix)\n",
+ "    print(\"true positive:\\t%d\\ntrue negative:\\t%d\\nfalse positive:\\t%d\\nfalse negative:\\t%d\"%(c_matrix[1,1], c_matrix[0,0], c_matrix[0,1], c_matrix[1,0]))\n",
+ "    precision = c_matrix[1,1] / (c_matrix[1,1] + c_matrix[0,1])\n",
+ "    recall = c_matrix[1,1] / (c_matrix[1,1] + c_matrix[1,0])\n",
+ "    # Accuracy over the real size of this fold instead of a hard-coded 1000\n",
+ "    accuracy = np.mean(np.equal(output, test[:,12]))\n",
+ "    print(\"\\naccuracy:\\t%f\"%accuracy)\n",
+ "    print(\"precision:\\t%f\\nrecall:\\t\\t%f\"%(precision, recall))\n",
+ "    print(\"\\n\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "oTjOYzBub9qx",
+ "outputId": "e75ae2af-9651-4f98-dc3e-fa7720e61ad8"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "confusion matrix:\n",
+ "[[846 5]\n",
+ " [ 83 66]]\n",
+ "true positive:\t66\n",
+ "true negative:\t846\n",
+ "false positive:\t5\n",
+ "false negative:\t83\n",
+ "\n",
+ "accuracy:\t0.912000\n",
+ "precision:\t0.929577\n",
+ "recall:\t\t0.442953\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/sklearn/neural_network/_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "confusion matrix:\n",
+ "[[861 10]\n",
+ " [ 60 69]]\n",
+ "true positive:\t69\n",
+ "true negative:\t861\n",
+ "false positive:\t10\n",
+ "false negative:\t60\n",
+ "\n",
+ "accuracy:\t0.930000\n",
+ "precision:\t0.873418\n",
+ "recall:\t\t0.534884\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[830 16]\n",
+ " [ 91 63]]\n",
+ "true positive:\t63\n",
+ "true negative:\t830\n",
+ "false positive:\t16\n",
+ "false negative:\t91\n",
+ "\n",
+ "accuracy:\t0.893000\n",
+ "precision:\t0.797468\n",
+ "recall:\t\t0.409091\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/sklearn/neural_network/_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "confusion matrix:\n",
+ "[[837 9]\n",
+ " [ 86 68]]\n",
+ "true positive:\t68\n",
+ "true negative:\t837\n",
+ "false positive:\t9\n",
+ "false negative:\t86\n",
+ "\n",
+ "accuracy:\t0.905000\n",
+ "precision:\t0.883117\n",
+ "recall:\t\t0.441558\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[855 10]\n",
+ " [ 83 52]]\n",
+ "true positive:\t52\n",
+ "true negative:\t855\n",
+ "false positive:\t10\n",
+ "false negative:\t83\n",
+ "\n",
+ "accuracy:\t0.907000\n",
+ "precision:\t0.838710\n",
+ "recall:\t\t0.385185\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[849 7]\n",
+ " [ 73 71]]\n",
+ "true positive:\t71\n",
+ "true negative:\t849\n",
+ "false positive:\t7\n",
+ "false negative:\t73\n",
+ "\n",
+ "accuracy:\t0.920000\n",
+ "precision:\t0.910256\n",
+ "recall:\t\t0.493056\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[861 7]\n",
+ " [ 68 64]]\n",
+ "true positive:\t64\n",
+ "true negative:\t861\n",
+ "false positive:\t7\n",
+ "false negative:\t68\n",
+ "\n",
+ "accuracy:\t0.925000\n",
+ "precision:\t0.901408\n",
+ "recall:\t\t0.484848\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[857 5]\n",
+ " [ 86 52]]\n",
+ "true positive:\t52\n",
+ "true negative:\t857\n",
+ "false positive:\t5\n",
+ "false negative:\t86\n",
+ "\n",
+ "accuracy:\t0.909000\n",
+ "precision:\t0.912281\n",
+ "recall:\t\t0.376812\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[862 8]\n",
+ " [ 72 58]]\n",
+ "true positive:\t58\n",
+ "true negative:\t862\n",
+ "false positive:\t8\n",
+ "false negative:\t72\n",
+ "\n",
+ "accuracy:\t0.920000\n",
+ "precision:\t0.878788\n",
+ "recall:\t\t0.446154\n",
+ "\n",
+ "\n",
+ "confusion matrix:\n",
+ "[[847 9]\n",
+ " [ 76 68]]\n",
+ "true positive:\t68\n",
+ "true negative:\t847\n",
+ "false positive:\t9\n",
+ "false negative:\t76\n",
+ "\n",
+ "accuracy:\t0.915000\n",
+ "precision:\t0.883117\n",
+ "recall:\t\t0.472222\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Model 3: a decision tree. The maximum accuracy and precision achieved over the 10-fold CV are 86.7% and 49.6%, respectively."
+ ],
+ "metadata": {
+ "id": "DrJ21i6rnmS8"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# 10-fold cross-validation of a default decision tree\n",
+ "kfold = KFold(n_splits=10)\n",
+ "\n",
+ "for i, (train_index, test_index) in enumerate(kfold.split(data)):\n",
+ "    # Columns 0..11 are the features, column 12 is the label\n",
+ "    train = data[train_index, :]\n",
+ "    test = data[test_index, :]\n",
+ "\n",
+ "    dtree_classifier = DecisionTreeClassifier()\n",
+ "    dtree_classifier.fit(train[:, :12], train[:,12])\n",
+ "\n",
+ "    output = dtree_classifier.predict(test[:,:12])\n",
+ "    print(\"iter %d:\\n------------------\"%(i+1))\n",
+ "\n",
+ "    #C_ij = i actual, j predicted\n",
+ "    c_matrix = confusion_matrix(test[:,12], output)\n",
+ "    print(\"confusion matrix:\")\n",
+ "    print(c_matrix)\n",
+ "    print(\"true positive:\\t%d\\ntrue negative:\\t%d\\nfalse positive:\\t%d\\nfalse negative:\\t%d\"%(c_matrix[1,1], c_matrix[0,0], c_matrix[0,1], c_matrix[1,0]))\n",
+ "    precision = c_matrix[1,1] / (c_matrix[1,1] + c_matrix[0,1])\n",
+ "    recall = c_matrix[1,1] / (c_matrix[1,1] + c_matrix[1,0])\n",
+ "    # Accuracy over the real size of this fold instead of a hard-coded 1000\n",
+ "    accuracy = np.mean(np.equal(output, test[:,12]))\n",
+ "    print(\"\\naccuracy:\\t%f\"%accuracy)\n",
+ "    print(\"precision:\\t%f\\nrecall:\\t\\t%f\"%(precision, recall))\n",
+ "    print(\"\\n\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "_NHWq08yBroB",
+ "outputId": "6dab9505-0938-4676-e910-1d24747b1759"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "iter 1:\n",
+ "------------------\n",
+ "confusion matrix:\n",
+ "[[779 72]\n",
+ " [ 80 69]]\n",
+ "true positive:\t69\n",
+ "true negative:\t779\n",
+ "false positive:\t72\n",
+ "false negative:\t80\n",
+ "\n",
+ "accuracy:\t0.848000\n",
+ "precision:\t0.489362\n",
+ "recall:\t\t0.463087\n",
+ "\n",
+ "\n",
+ "iter 2:\n",
+ "------------------\n",
+ "confusion matrix:\n",
+ "[[767 104]\n",
+ " [ 59 70]]\n",
+ "true positive:\t70\n",
+ "true negative:\t767\n",
+ "false positive:\t104\n",
+ "false negative:\t59\n",
+ "\n",
+ "accuracy:\t0.837000\n",
+ "precision:\t0.402299\n",
+ "recall:\t\t0.542636\n",
+ "\n",
+ "\n",
+ "iter 3:\n",
+ "------------------\n",
+ "confusion matrix:\n",
+ "[[769 77]\n",
+ " [ 85 69]]\n",
+ "true positive:\t69\n",
+ "true negative:\t769\n",
+ "false positive:\t77\n",
+ "false negative:\t85\n",
+ "\n",
+ "accuracy:\t0.838000\n",
+ "precision:\t0.472603\n",
+ "recall:\t\t0.448052\n",
+ "\n",
+ "\n",
+ "iter 4:\n",
+ "------------------\n",
+ "confusion matrix:\n",
+ "[[762 84]\n",
+ " [ 90 64]]\n",
+ "true positive:\t64\n",
+ "true negative:\t762\n",
+ "false positive:\t84\n",
+ "false negative:\t90\n",
+ "\n",
+ "accuracy:\t0.826000\n",
+ "precision:\t0.432432\n",
+ "recall:\t\t0.415584\n",
+ "\n",
+ "\n",
+ "iter 5:\n",
+ "------------------\n",
+ "confusion matrix:\n",
+ "[[776 89]\n",
+ " [ 81 54]]\n",
+ "true positive:\t54\n",
+ "true negative:\t776\n",
+ "false positive:\t89\n",
+ "false negative:\t81\n",
+ "\n",
+ "accuracy:\t0.830000\n",
+ "precision:\t0.377622\n",
+ "recall:\t\t0.400000\n",
+ "\n",
+ "\n",
+ "iter 6:\n",
+ "------------------\n",
+ "confusion matrix:\n",
+ "[[768 88]\n",
+ " [ 76 68]]\n",
+ "true positive:\t68\n",
+ "true negative:\t768\n",
+ "false positive:\t88\n",
+ "false negative:\t76\n",
+ "\n",
+ "accuracy:\t0.836000\n",
+ "precision:\t0.435897\n",
+ "recall:\t\t0.472222\n",
+ "\n",
+ "\n",
+ "iter 7:\n",
+ "------------------\n",
+ "confusion matrix:\n",
+ "[[804 64]\n",
+ " [ 69 63]]\n",
+ "true positive:\t63\n",
+ "true negative:\t804\n",
+ "false positive:\t64\n",
+ "false negative:\t69\n",
+ "\n",
+ "accuracy:\t0.867000\n",
+ "precision:\t0.496063\n",
+ "recall:\t\t0.477273\n",
+ "\n",
+ "\n",
+ "iter 8:\n",
+ "------------------\n",
+ "confusion matrix:\n",
+ "[[780 82]\n",
+ " [ 86 52]]\n",
+ "true positive:\t52\n",
+ "true negative:\t780\n",
+ "false positive:\t82\n",
+ "false negative:\t86\n",
+ "\n",
+ "accuracy:\t0.832000\n",
+ "precision:\t0.388060\n",
+ "recall:\t\t0.376812\n",
+ "\n",
+ "\n",
+ "iter 9:\n",
+ "------------------\n",
+ "confusion matrix:\n",
+ "[[787 83]\n",
+ " [ 80 50]]\n",
+ "true positive:\t50\n",
+ "true negative:\t787\n",
+ "false positive:\t83\n",
+ "false negative:\t80\n",
+ "\n",
+ "accuracy:\t0.837000\n",
+ "precision:\t0.375940\n",
+ "recall:\t\t0.384615\n",
+ "\n",
+ "\n",
+ "iter 10:\n",
+ "------------------\n",
+ "confusion matrix:\n",
+ "[[767 89]\n",
+ " [ 64 80]]\n",
+ "true positive:\t80\n",
+ "true negative:\t767\n",
+ "false positive:\t89\n",
+ "false negative:\t64\n",
+ "\n",
+ "accuracy:\t0.847000\n",
+ "precision:\t0.473373\n",
+ "recall:\t\t0.555556\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Accuracy of the neural network as PCA keeps 12, 11, ..., 2 components.\n",
+ "N_FEATURES = 12\n",
+ "N_TRAIN = int(0.9 * len(data))  # 90/10 hold-out split instead of hard-coded 9000/1000\n",
+ "\n",
+ "X_train, y_train = data[:N_TRAIN, :N_FEATURES], data[:N_TRAIN, -1]\n",
+ "X_test, y_test = data[N_TRAIN:, :N_FEATURES], data[N_TRAIN:, -1]\n",
+ "\n",
+ "component_counts = list(range(N_FEATURES, 1, -1))\n",
+ "plot_data = np.zeros((len(component_counts), 2))\n",
+ "\n",
+ "for row, n_components in enumerate(component_counts):\n",
+ "    # Fit PCA on ALL features of the training rows only, then project both splits.\n",
+ "    # (Slicing the first i raw columns, or dropping the last projected column, is not PCA reduction.)\n",
+ "    pca = PCA(n_components = n_components)\n",
+ "    train = pca.fit_transform(X_train)\n",
+ "    test = pca.transform(X_test)\n",
+ "\n",
+ "    nn_classifier = MLPClassifier()\n",
+ "    nn_classifier.fit(train, y_train)\n",
+ "    output = nn_classifier.predict(test)\n",
+ "\n",
+ "    accuracy = np.mean(np.equal(output, y_test))\n",
+ "    # x = number of dimensions removed, y = hold-out accuracy\n",
+ "    plot_data[row] = [N_FEATURES - n_components, accuracy]\n",
+ "\n",
+ "plt.plot(plot_data[:,0], plot_data[:,1])\n",
+ "plt.title(\"neural network accuracy as increased dimensions are reduced via PCA\")\n",
+ "plt.xlabel(\"dimensions reduced\")\n",
+ "plt.ylabel(\"accuracy\")\n",
+ "plt.show()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 472
+ },
+ "id": "W1TtAYJqlk2C",
+ "outputId": "6a8f0005-862b-4ccc-c924-faa5f2ca4fe7"
+ },
+ "execution_count": 15,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "image/png": "\n"
+ },
+ "metadata": {}
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "As can be seen, the accuracy of the model drops off once PCA removes 9 dimensions, i.e., once the data is projected down to 3 dimensions or fewer."
+ ],
+ "metadata": {
+ "id": "eDLdI9APnJzQ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "0GqOmTCkonxi"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 31e09f23..22a94e88 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
extras = dict()
-extras['test'] = ['cmake', 'ninja', 'nle>=0.9.0', 'matplotlib>=3.7.1', 'scipy==1.10.0', 'tensorboard>=2.13.0', 'shimmy']
+extras['test'] = ['cmake', 'matplotlib>=3.7.1', 'scipy==1.10.0', 'tensorboard>=2.13.0', 'shimmy']
extras['docs'] = ['sphinx-tabs', 'sphinxcontrib-spelling', 'furo']
extras['all'] = extras['test'] + extras['docs']
diff --git a/syllabus/core/task_interface/task_wrapper.py b/syllabus/core/task_interface/task_wrapper.py
index 1de3970d..e114a2b3 100644
--- a/syllabus/core/task_interface/task_wrapper.py
+++ b/syllabus/core/task_interface/task_wrapper.py
@@ -85,6 +85,24 @@ def __getattr__(self, attr):
return env_attr
+# NOTE(review): leftover from an unresolved merge conflict. The commented-out class below
+# is the superseded HEAD-side PettingZooTaskWrapper; it is kept for reference only.
+# class PettingZooTaskWrapper(TaskWrapper, BaseParallelWrapper):
+#     def __init__(self, env: pettingzoo.ParallelEnv):
+#         super().__init__(env)
+#         self.task = None
+
+#     @property
+#     def agents(self):
+#         return self.env.agents
+
+#     def __getattr__(self, attr):
+#         env_attr = getattr(self.env, attr, None)
+#         if env_attr:
+#             return env_attr
+
+#     def get_current_task(self):
+#         return self.current_task
+# (end of the superseded HEAD-side version)
class PettingZooTaskWrapper(BaseParallelWrapper):
def __init__(self, env: pettingzoo.ParallelEnv):
super().__init__(env)
@@ -159,3 +177,4 @@ def _task_completion(self, obs, rew, term, trunc, info) -> float:
"""
# return 1.0 if term or trunc else 0.0
return info
+# NOTE(review): leftover merge-conflict end marker; the active PettingZooTaskWrapper code above came from b88c2fcba4658545e156188c85f48f0b1e54aab2.
diff --git a/syllabus/examples/custom_envs/__init__.py b/syllabus/examples/custom_envs/__init__.py
new file mode 100644
index 00000000..d80c350e
--- /dev/null
+++ b/syllabus/examples/custom_envs/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from custom_envs.obstructedmaze_gamut import ObstructedMazeGamut
diff --git a/syllabus/examples/custom_envs/obstructedmaze_fixedgrid.py b/syllabus/examples/custom_envs/obstructedmaze_fixedgrid.py
new file mode 100644
index 00000000..89f9e53e
--- /dev/null
+++ b/syllabus/examples/custom_envs/obstructedmaze_fixedgrid.py
@@ -0,0 +1,232 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from gym_minigrid.minigrid import *
+from gym_minigrid.roomgrid import RoomGrid
+from gym_minigrid.register import register
+
+class ObstructedMazeEnvFixedGrid(RoomGrid):
+    """
+    Base class for the fixed-grid obstructed mazes.
+
+    A ball has to be found and picked up. Doors may be locked, doors may be
+    obstructed by a ball, and keys may be hidden in boxes.
+    """
+
+    def __init__(self, num_rows, num_cols, num_rooms_visited, seed=None):
+        side = 7
+        # The step budget grows with the number of rooms expected to be visited
+        # and with the area of a room.
+        step_budget = 4 * num_rooms_visited * side ** 2
+
+        super().__init__(
+            room_size=side,
+            num_rows=num_rows,
+            num_cols=num_cols,
+            frame_rows=3,
+            frame_cols=3,
+            max_steps=step_budget,
+            seed=seed,
+        )
+
+    def _gen_grid(self, width, height):
+        """Build the room grid and fix the colour scheme used by this episode."""
+        super()._gen_grid(width, height)
+
+        # Door colours: a random ordering of every available colour
+        self.door_colors = self._rand_subset(COLOR_NAMES, len(COLOR_NAMES))
+        # The target ball, the door-blocking balls and the key boxes each use
+        # one fixed entry of COLOR_NAMES
+        self.ball_to_find_color = COLOR_NAMES[0]
+        self.blocking_ball_color = COLOR_NAMES[1]
+        self.box_color = COLOR_NAMES[2]
+
+        self.mission = "pick up the %s ball" % self.ball_to_find_color
+
+    def step(self, action):
+        """Run one step; picking up the target object ends the episode with a reward."""
+        obs, reward, done, info = super().step(action)
+
+        picked_up_target = (
+            action == self.actions.pickup
+            and self.carrying
+            and self.carrying == self.obj
+        )
+        if picked_up_target:
+            reward = self._reward()
+            done = True
+
+        return obs, reward, done, info
+
+    def add_door(self, i, j, door_idx=0, color=None, locked=False, key_in_box=False, blocked=False):
+        """
+        Add a door to room (i, j) and return ``(door, door_pos)``.
+
+        A blocked door gets a ball on the cell next to it (door position minus
+        DIR_TO_VEC[door_idx]). A locked door gets a matching key placed in the
+        room, wrapped in a box when ``key_in_box`` is set.
+        """
+        door, door_pos = super().add_door(i, j, door_idx, color, locked=locked)
+
+        if blocked:
+            offset = DIR_TO_VEC[door_idx]
+            obstacle = Ball(self.blocking_ball_color)
+            self.grid.set(door_pos[0] - offset[0], door_pos[1] - offset[1], obstacle)
+
+        if locked:
+            item = Key(door.color)
+            if key_in_box:
+                container = Box(self.box_color)
+                container.contains = item
+                item = container
+            self.place_in_room(i, j, item)
+
+        return door, door_pos
+
+class ObstructedMaze_1Dlhb(ObstructedMazeEnvFixedGrid):
+    """
+    Two rooms side by side (1 row x 2 columns) separated by a locked door.
+    By default the door is obstructed by a ball and its key is hidden in a
+    box; the ball to pick up is in the second room.
+    """
+
+    def __init__(self, key_in_box=True, blocked=True, seed=None):
+        # Stored before the parent constructor runs, as in the original ordering;
+        # _gen_grid reads both flags.
+        self.key_in_box = key_in_box
+        self.blocked = blocked
+
+        super().__init__(num_rows=1, num_cols=2, num_rooms_visited=2, seed=seed)
+
+    def _gen_grid(self, width, height):
+        """Locked door in room (0, 0), target ball in room (1, 0), agent in room (0, 0)."""
+        super()._gen_grid(width, height)
+
+        self.add_door(
+            0, 0,
+            door_idx=0,
+            color=self.door_colors[0],
+            locked=True,
+            key_in_box=self.key_in_box,
+            blocked=self.blocked,
+        )
+
+        target, _ = self.add_object(1, 0, "ball", color=self.ball_to_find_color)
+        self.obj = target
+        self.place_agent(0, 0)
+
+class ObstructedMaze_1Dl(ObstructedMaze_1Dlhb):
+    """1Dlhb variant with the key in plain sight and the door unobstructed."""
+    def __init__(self, seed=None):
+        super().__init__(key_in_box=False, blocked=False, seed=seed)
+
+class ObstructedMaze_1Dlh(ObstructedMaze_1Dlhb):
+    """1Dlhb variant with the key hidden in a box and the door unobstructed."""
+    def __init__(self, seed=None):
+        super().__init__(key_in_box=True, blocked=False, seed=seed)
+
+class ObstructedMaze_Full(ObstructedMazeEnvFixedGrid):
+    """
+    3x3 maze with the target ball hidden in one of the corner rooms. With the
+    default arguments the doors around the side rooms are locked and
+    obstructed by a ball, and their keys are hidden in boxes.
+    """
+
+    def __init__(self, agent_room=(1, 1), key_in_box=True, blocked=True,
+                 num_quarters=4, num_rooms_visited=25, seed=None):
+        # All four settings are read by _gen_grid, so they are stored before
+        # the parent constructor runs (same ordering as the original).
+        self.agent_room = agent_room
+        self.key_in_box = key_in_box
+        self.blocked = blocked
+        self.num_quarters = num_quarters
+
+        super().__init__(num_rows=3, num_cols=3,
+                         num_rooms_visited=num_rooms_visited, seed=seed)
+
+    def _gen_grid(self, width, height):
+        """Wire up the doors for the enabled quarters, then place the ball and the agent."""
+        super()._gen_grid(width, height)
+
+        center = (1, 1)
+        # "Side rooms" are neither corners nor the center; only the first
+        # num_quarters of them are connected.
+        enabled_sides = [(2, 1), (1, 2), (0, 1), (1, 0)][:self.num_quarters]
+        for i, side_room in enumerate(enabled_sides):
+            # Unlocked door between the center room and this side room
+            self.add_door(*center, door_idx=i, color=self.door_colors[i], locked=False)
+
+            # Locked door on each of the two neighbouring walls of the side room
+            for k in (-1, 1):
+                neighbour = i + k
+                self.add_door(*side_room, locked=True,
+                              door_idx=neighbour % 4,
+                              color=self.door_colors[neighbour % len(self.door_colors)],
+                              key_in_box=self.key_in_box,
+                              blocked=self.blocked)
+
+        # The ball goes into a random corner among the enabled quarters
+        candidate_corners = [(2, 0), (2, 2), (0, 2), (0, 0)][:self.num_quarters]
+        ball_room = self._rand_elem(candidate_corners)
+
+        self.obj, _ = self.add_object(*ball_room, "ball", color=self.ball_to_find_color)
+        self.place_agent(*self.agent_room)
+
+class ObstructedMaze_2Dl(ObstructedMaze_Full):
+    """One quarter, agent starts in room (2, 1); keys visible, doors unobstructed."""
+    def __init__(self, seed=None):
+        super().__init__(agent_room=(2, 1), key_in_box=False, blocked=False,
+                         num_quarters=1, num_rooms_visited=4, seed=seed)
+
+class ObstructedMaze_2Dlh(ObstructedMaze_Full):
+    """One quarter, agent starts in room (2, 1); keys in boxes, doors unobstructed."""
+    def __init__(self, seed=None):
+        super().__init__(agent_room=(2, 1), key_in_box=True, blocked=False,
+                         num_quarters=1, num_rooms_visited=4, seed=seed)
+
+
+class ObstructedMaze_2Dlhb(ObstructedMaze_Full):
+    """One quarter, agent starts in room (2, 1); keys in boxes, doors obstructed."""
+    def __init__(self, seed=None):
+        super().__init__(agent_room=(2, 1), key_in_box=True, blocked=True,
+                         num_quarters=1, num_rooms_visited=4, seed=seed)
+
+class ObstructedMaze_1Q(ObstructedMaze_Full):
+    """One quarter, agent starts in the center room; keys in boxes, doors obstructed."""
+    def __init__(self, seed=None):
+        super().__init__(agent_room=(1, 1), key_in_box=True, blocked=True,
+                         num_quarters=1, num_rooms_visited=5, seed=seed)
+
+class ObstructedMaze_2Q(ObstructedMaze_Full):
+    """Two quarters, agent starts in the center room; keys in boxes, doors obstructed."""
+    def __init__(self, seed=None):
+        super().__init__(agent_room=(1, 1), key_in_box=True, blocked=True,
+                         num_quarters=2, num_rooms_visited=11, seed=seed)
+
+# Register every fixed-grid variant under its own environment id. Entry points
+# are "<this module>:<class name>"; the registration order is unchanged.
+for _env_id, _class_name in (
+    ("MiniGrid-ObstructedMaze-1Dl-fixed_grid-v0", "ObstructedMaze_1Dl"),
+    ("MiniGrid-ObstructedMaze-1Dlh-fixed_grid-v0", "ObstructedMaze_1Dlh"),
+    ("MiniGrid-ObstructedMaze-1Dlhb-fixed_grid-v0", "ObstructedMaze_1Dlhb"),
+    ("MiniGrid-ObstructedMaze-2Dl-fixed_grid-v0", "ObstructedMaze_2Dl"),
+    ("MiniGrid-ObstructedMaze-2Dlh-fixed_grid-v0", "ObstructedMaze_2Dlh"),
+    ("MiniGrid-ObstructedMaze-2Dlhb-fixed_grid-v0", "ObstructedMaze_2Dlhb"),
+    ("MiniGrid-ObstructedMaze-1Q-fixed_grid-v0", "ObstructedMaze_1Q"),
+    ("MiniGrid-ObstructedMaze-2Q-fixed_grid-v0", "ObstructedMaze_2Q"),
+    ("MiniGrid-ObstructedMaze-Full-fixed_grid-v0", "ObstructedMaze_Full"),
+):
+    register(id=_env_id, entry_point=f"{__name__}:{_class_name}")
\ No newline at end of file
diff --git a/syllabus/examples/custom_envs/obstructedmaze_gamut.py b/syllabus/examples/custom_envs/obstructedmaze_gamut.py
new file mode 100644
index 00000000..56efa0d2
--- /dev/null
+++ b/syllabus/examples/custom_envs/obstructedmaze_gamut.py
@@ -0,0 +1,185 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import gym
+from gym_minigrid.register import register
+
+from custom_envs.obstructedmaze_fixedgrid import ObstructedMazeEnvFixedGrid
+
+
+# Environment ids of the fixed-grid mazes. ObstructedMazeGamut takes a prefix of
+# this list (its first max_difficulty entries), so the order matters.
+ALL_SUBENVS = [
+    "MiniGrid-ObstructedMaze-1Dl-fixed_grid-v0",
+    "MiniGrid-ObstructedMaze-1Dlh-fixed_grid-v0",
+    "MiniGrid-ObstructedMaze-1Dlhb-fixed_grid-v0",
+    "MiniGrid-ObstructedMaze-2Dl-fixed_grid-v0",
+    "MiniGrid-ObstructedMaze-2Dlh-fixed_grid-v0",
+    "MiniGrid-ObstructedMaze-2Dlhb-fixed_grid-v0",
+    "MiniGrid-ObstructedMaze-1Q-fixed_grid-v0",
+    "MiniGrid-ObstructedMaze-2Q-fixed_grid-v0",
+    "MiniGrid-ObstructedMaze-Full-fixed_grid-v0",
+]
+
+# Default tile size passed to render(); get_obs_render() defaults to half of it.
+TILE_PIXELS = 32
+
+
+class ObstructedMazeGamut(gym.Env):
+    """
+    Meta-environment that multiplexes the fixed-grid ObstructedMaze variants.
+
+    The first ``max_difficulty`` entries of ALL_SUBENVS are instantiated once.
+    ``seed()`` selects the active one (``seed % num_subenvs``); every other
+    method and property is forwarded to that active sub-environment.
+
+    Args:
+        distribution: 'easy', 'medium' or 'hard' (3, 6 or 9 sub-environments).
+        max_difficulty: optional override for the number of sub-environments.
+        seed: integer seed; it also selects the active sub-environment.
+
+    Raises:
+        ValueError: for an unknown distribution, or when ``max_difficulty``
+            selects no sub-environment at all.
+    """
+    def __init__(self, distribution='easy', max_difficulty=None, seed=1337):
+
+        self.distribution = distribution
+        if distribution == 'easy':
+            self.max_difficulty = 3
+        elif distribution == 'medium':
+            self.max_difficulty = 6
+        elif distribution == 'hard':
+            self.max_difficulty = 9
+        else:
+            raise ValueError(f'Unsupported distribution {distribution}.')
+
+        if max_difficulty is not None:
+            self.max_difficulty = max_difficulty
+
+        subenv_names = ALL_SUBENVS[:self.max_difficulty]
+        if not subenv_names:
+            # Fail here with a clear message instead of a ZeroDivisionError in seed()
+            raise ValueError(
+                f'max_difficulty={self.max_difficulty} selects no sub-environment.')
+
+        self.subenvs = [gym.make(env_name) for env_name in subenv_names]
+        self.num_subenvs = len(self.subenvs)
+
+        # seed() sets self.env, so it has to run before the first reset()
+        self.seed(seed)
+        self.reset()
+
+    # --- read-only views onto the active sub-environment ---
+
+    @property
+    def actions(self):
+        return self.env.actions
+
+    @property
+    def agent_view_size(self):
+        return self.env.agent_view_size
+
+    @property
+    def reward_range(self):
+        return self.env.reward_range
+
+    @property
+    def window(self):
+        return self.env.window
+
+    @property
+    def width(self):
+        return self.env.width
+
+    @property
+    def height(self):
+        return self.env.height
+
+    @property
+    def grid(self):
+        return self.env.grid
+
+    @property
+    def max_steps(self):
+        return self.env.max_steps
+
+    @property
+    def see_through_walls(self):
+        return self.env.see_through_walls
+
+    @property
+    def agent_pos(self):
+        return self.env.agent_pos
+
+    @property
+    def agent_dir(self):
+        return self.env.agent_dir
+
+    @property
+    def step_count(self):
+        return self.env.step_count
+
+    @property
+    def carrying(self):
+        return self.env.carrying
+
+    @property
+    def observation_space(self):
+        return self.env.observation_space
+
+    @property
+    def action_space(self):
+        return self.env.action_space
+
+    @property
+    def steps_remaining(self):
+        return self.env.steps_remaining
+
+    def __str__(self):
+        return self.env.__str__()
+
+    # --- forwarded methods ---
+
+    def reset(self):
+        """Reset the active sub-environment and return its result."""
+        return self.env.reset()
+
+    def seed(self, seed=1337):
+        """
+        Select the active sub-environment (``seed % num_subenvs``) and seed it.
+
+        Returns whatever the sub-environment's own ``seed`` returns.
+        """
+        env_index = seed % self.num_subenvs
+        self.env = self.subenvs[env_index]
+        return self.env.seed(seed)
+
+    def hash(self, size=16):
+        return self.env.hash(size)
+
+    def relative_coords(self, x, y):
+        return self.env.relative_coords(x, y)
+
+    def in_view(self, x, y):
+        return self.env.in_view(x, y)
+
+    def agent_sees(self, x, y):
+        return self.env.agent_sees(x, y)
+
+    def step(self, action):
+        return self.env.step(action)
+
+    def gen_obs_grid(self):
+        return self.env.gen_obs_grid()
+
+    def gen_obs(self):
+        return self.env.gen_obs()
+
+    def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
+        return self.env.get_obs_render(obs, tile_size)
+
+    def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
+        # Forward the options by keyword: gym.make() may return the environment
+        # behind a gym.Wrapper, whose render() only accepts `mode` positionally.
+        # NOTE(review): assumes the sub-environment's render() uses these same
+        # parameter names - confirm against the installed gym_minigrid.
+        return self.env.render(mode=mode, close=close, highlight=highlight, tile_size=tile_size)
+
+    def close(self):
+        return self.env.close()
+
+
+class ObstructedMazeGamut_Easy(ObstructedMazeGamut):
+    """Gamut preset that uses the 'easy' distribution."""
+    def __init__(self, seed=1337):
+        super().__init__('easy', seed=seed)
+
+class ObstructedMazeGamut_Medium(ObstructedMazeGamut):
+    """Gamut preset that uses the 'medium' distribution."""
+    def __init__(self, seed=1337):
+        super().__init__('medium', seed=seed)
+
+class ObstructedMazeGamut_Hard(ObstructedMazeGamut):
+    """Gamut preset that uses the 'hard' distribution."""
+    def __init__(self, seed=1337):
+        super().__init__('hard', seed=seed)
+
+
+# Register the three gamut presets; entry points are "<this module>:<class name>".
+for _env_id, _class_name in (
+    ("MiniGrid-ObstructedMazeGamut-Easy-v0", "ObstructedMazeGamut_Easy"),
+    ("MiniGrid-ObstructedMazeGamut-Medium-v0", "ObstructedMazeGamut_Medium"),
+    ("MiniGrid-ObstructedMazeGamut-Hard-v0", "ObstructedMazeGamut_Hard"),
+):
+    register(id=_env_id, entry_point=f"{__name__}:{_class_name}")
diff --git a/syllabus/examples/models/__pycache__/__init__.cpython-38.pyc b/syllabus/examples/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 00000000..214f3fd7
Binary files /dev/null and b/syllabus/examples/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/syllabus/examples/models/__pycache__/__init__.cpython-39.pyc b/syllabus/examples/models/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 00000000..eb388eeb
Binary files /dev/null and b/syllabus/examples/models/__pycache__/__init__.cpython-39.pyc differ
diff --git a/syllabus/examples/models/__pycache__/minigrid_model.cpython-38.pyc b/syllabus/examples/models/__pycache__/minigrid_model.cpython-38.pyc
new file mode 100644
index 00000000..e6f8c6bb
Binary files /dev/null and b/syllabus/examples/models/__pycache__/minigrid_model.cpython-38.pyc differ
diff --git a/syllabus/examples/models/__pycache__/minigrid_model.cpython-39.pyc b/syllabus/examples/models/__pycache__/minigrid_model.cpython-39.pyc
new file mode 100644
index 00000000..34242b83
Binary files /dev/null and b/syllabus/examples/models/__pycache__/minigrid_model.cpython-39.pyc differ
diff --git a/syllabus/examples/models/__pycache__/minigrid_model_verma.cpython-38.pyc b/syllabus/examples/models/__pycache__/minigrid_model_verma.cpython-38.pyc
new file mode 100644
index 00000000..c4275b5b
Binary files /dev/null and b/syllabus/examples/models/__pycache__/minigrid_model_verma.cpython-38.pyc differ
diff --git a/syllabus/examples/models/__pycache__/minigrid_model_verma.cpython-39.pyc b/syllabus/examples/models/__pycache__/minigrid_model_verma.cpython-39.pyc
new file mode 100644
index 00000000..a54ebc4e
Binary files /dev/null and b/syllabus/examples/models/__pycache__/minigrid_model_verma.cpython-39.pyc differ
diff --git a/syllabus/examples/models/__pycache__/procgen_model.cpython-38.pyc b/syllabus/examples/models/__pycache__/procgen_model.cpython-38.pyc
new file mode 100644
index 00000000..97ec398d
Binary files /dev/null and b/syllabus/examples/models/__pycache__/procgen_model.cpython-38.pyc differ
diff --git a/syllabus/examples/models/__pycache__/procgen_model.cpython-39.pyc b/syllabus/examples/models/__pycache__/procgen_model.cpython-39.pyc
new file mode 100644
index 00000000..c71f46c1
Binary files /dev/null and b/syllabus/examples/models/__pycache__/procgen_model.cpython-39.pyc differ
diff --git a/syllabus/examples/models/minigrid_model_verma.py b/syllabus/examples/models/minigrid_model_verma.py
new file mode 100644
index 00000000..bb4efbd4
--- /dev/null
+++ b/syllabus/examples/models/minigrid_model_verma.py
@@ -0,0 +1,185 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+def init(module, weight_init, bias_init, gain=1):
+    """Initialise ``module`` in place and hand the same object back.
+
+    Args:
+        module: layer exposing ``weight`` and ``bias`` parameters.
+        weight_init: callable invoked as ``weight_init(weight, gain=gain)``.
+        bias_init: callable invoked as ``bias_init(bias)``.
+        gain: scaling factor forwarded to ``weight_init``.
+
+    Returns:
+        The ``module`` that was passed in.
+    """
+    target = module.weight.data
+    weight_init(target, gain=gain)
+    target = module.bias.data
+    bias_init(target)
+    return module
+
+def _zero_bias(bias):
+    """Fill ``bias`` with zeros in place."""
+    return nn.init.constant_(bias, 0)
+
+
+def init_(m):
+    """Orthogonal weights with the default gain of 1, zero bias."""
+    return init(m, nn.init.orthogonal_, _zero_bias)
+
+
+def init_relu_(m):
+    """Orthogonal weights scaled by the ReLU gain, zero bias."""
+    return init(m, nn.init.orthogonal_, _zero_bias, nn.init.calculate_gain('relu'))
+
+
+def init_tanh_(m):
+    """Orthogonal weights scaled by sqrt(2), zero bias."""
+    return init(m, nn.init.orthogonal_, _zero_bias, np.sqrt(2))
+
+class FixedCategorical(torch.distributions.Categorical):
+    """Categorical distribution whose helpers keep a trailing axis of size 1."""
+
+    def sample(self):
+        """Draw one action per batch element, returned with an extra last axis."""
+        drawn = super().sample()
+        return drawn.unsqueeze(-1)
+
+    def log_probs(self, actions):
+        """Return the log-probability of ``actions`` with a trailing unit axis.
+
+        The last axis of ``actions`` is squeezed before the lookup; the result
+        is reshaped to ``(actions.size(0), -1)``, summed over the last axis and
+        unsqueezed again.
+        """
+        per_action = super().log_prob(actions.squeeze(-1))
+        per_row = per_action.view(actions.size(0), -1)
+        return per_row.sum(-1).unsqueeze(-1)
+
+    def mode(self):
+        """Return the index of the most probable category, keeping the last axis."""
+        return torch.argmax(self.probs, dim=-1, keepdim=True)
+
+class Categorical(nn.Module):
+    """Linear head that turns features into a ``FixedCategorical`` distribution."""
+
+    def __init__(self, num_inputs, num_outputs):
+        super().__init__()
+
+        # Orthogonal weights with a small gain (0.01) and a zero bias.
+        head = nn.Linear(num_inputs, num_outputs)
+        self.linear = init(
+            head,
+            nn.init.orthogonal_,
+            lambda bias: nn.init.constant_(bias, 0),
+            gain=0.01,
+        )
+
+    def forward(self, x):
+        """Project ``x`` to logits and wrap them in a distribution object."""
+        return FixedCategorical(logits=self.linear(x))
+
+def apply_init_(modules):
+    """Re-initialise convolution and normalisation layers in place.
+
+    ``nn.Conv2d`` weights get Xavier-uniform values; ``nn.BatchNorm2d`` and
+    ``nn.GroupNorm`` weights are set to 1. In both cases an existing bias is
+    zeroed. Every other module type is left untouched.
+    """
+    norm_types = (nn.BatchNorm2d, nn.GroupNorm)
+    for layer in modules:
+        if isinstance(layer, nn.Conv2d):
+            nn.init.xavier_uniform_(layer.weight)
+        elif isinstance(layer, norm_types):
+            nn.init.constant_(layer.weight, 1)
+        else:
+            continue
+        if layer.bias is not None:
+            nn.init.constant_(layer.bias, 0)
+
+class MinigridPolicyVerma(nn.Module):
+    """Actor-critic network: a shared conv encoder feeding separate actor and critic heads."""
+
+    def __init__(self, obs_shape, num_actions, arch='small', base_kwargs=None):
+        # ``base_kwargs`` is accepted for interface compatibility; nothing in
+        # this class reads it.
+        super().__init__()
+
+        final_channels = 64 if arch != 'small' else 32
+
+        conv_layers = [
+            nn.Conv2d(3, 16, (2, 2)),
+            nn.ReLU(),
+            nn.MaxPool2d((2, 2)),
+            nn.Conv2d(16, 32, (2, 2)),
+            nn.ReLU(),
+            nn.Conv2d(32, final_channels, (2, 2)),
+            nn.ReLU(),
+        ]
+        self.image_conv = nn.Sequential(*conv_layers)
+
+        # Spatial size after the stack: a 2x2 conv (-1), a 2x2 max-pool (//2)
+        # and two more 2x2 convs (-2), applied to each of the last two axes.
+        height, width = obs_shape[-2], obs_shape[-1]
+        self.image_embedding_size = ((height - 1) // 2 - 2) * ((width - 1) // 2 - 2) * final_channels
+        self.embedding_size = self.image_embedding_size
+
+        # Actor trunk
+        actor_hidden = init_tanh_(nn.Linear(self.embedding_size, 64))
+        self.actor_base = nn.Sequential(actor_hidden, nn.Tanh())
+
+        # Critic: hidden layer followed by a scalar value output
+        critic_hidden = init_tanh_(nn.Linear(self.embedding_size, 64))
+        critic_out = init_(nn.Linear(64, 1))
+        self.critic = nn.Sequential(critic_hidden, nn.Tanh(), critic_out)
+
+        self.dist = Categorical(64, num_actions)
+
+        apply_init_(self.modules())
+
+        self.train()
+
+    def _embed(self, inputs):
+        """Run the conv encoder and flatten everything after the batch axis."""
+        return self.image_conv(inputs).flatten(1, -1)
+
+    @property
+    def is_recurrent(self):
+        return False
+
+    @property
+    def recurrent_hidden_state_size(self):
+        """Size of rnn_hx."""
+        return 1
+
+    def forward(self, inputs, rnn_hxs, masks):
+        raise NotImplementedError
+
+    def act(self, inputs, deterministic=False):
+        """Pick an action for ``inputs``.
+
+        Returns:
+            ``(action, logits, mean_entropy, value)`` where ``logits`` is the
+            ``logits`` attribute of the action distribution. The action is the
+            distribution mode when ``deterministic`` is true, a sample otherwise.
+        """
+        embedding = self._embed(inputs)
+        actor_features = self.actor_base(embedding)
+        value = self.critic(embedding)
+        dist = self.dist(actor_features)
+
+        action = dist.mode() if deterministic else dist.sample()
+
+        action_log_dist = dist.logits
+        dist_entropy = dist.entropy().mean()
+
+        return action, action_log_dist, dist_entropy, value
+
+    def get_value(self, inputs, rnn_hxs, masks):
+        """Return the critic's value estimate; ``rnn_hxs`` and ``masks`` are unused."""
+        return self.critic(self._embed(inputs))
+
+    def evaluate_actions(self, inputs, rnn_hxs, masks, action):
+        """Score ``action`` under the current policy.
+
+        Returns:
+            ``(value, action_log_probs, mean_entropy, rnn_hxs)``; ``rnn_hxs``
+            is passed through unchanged and ``masks`` is unused.
+        """
+        embedding = self._embed(inputs)
+        actor_features = self.actor_base(embedding)
+        value = self.critic(embedding)
+        dist = self.dist(actor_features)
+
+        action_log_probs = dist.log_probs(action)
+        dist_entropy = dist.entropy().mean()
+
+        return value, action_log_probs, dist_entropy, rnn_hxs
+
+class MinigridAgentVerma(MinigridPolicyVerma):
+    """Policy variant exposing the ``get_value`` / ``get_action_and_value`` pair used by the PPO loop."""
+
+    def get_value(self, x):
+        """Return the critic's value estimate for the observation batch ``x``."""
+        x = self.image_conv(x)
+        x = x.flatten(1, -1)
+        return self.critic(x)
+
+    def get_action_and_value(self, x, action=None, full_log_probs=False):
+        """Sample (or score) an action and evaluate the critic.
+
+        Args:
+            x: batch of observations accepted by ``image_conv``.
+            action: optional batch of actions to score. When ``None`` a fresh
+                action is sampled. When given, the returned log-probabilities
+                are those of *these* actions, which the PPO ratio
+                ``exp(new_logprob - old_logprob)`` requires.
+            full_log_probs: when true, also return the log of the full action
+                probability vector.
+
+        Returns:
+            ``(action, action_log_probs, entropy, value)``, plus ``log_probs``
+            as a fifth element when ``full_log_probs`` is true. ``action`` and
+            ``action_log_probs`` are 1-D with one entry per batch element.
+        """
+        x = self.image_conv(x)
+        x = x.flatten(1, -1)
+        actor_features = self.actor_base(x)
+        value = self.critic(x)
+        dist = self.dist(actor_features)
+
+        # Only sample when the caller did not supply actions to score.
+        if action is None:
+            action = dist.sample()
+
+        # ``log_probs`` reads ``actions.size(0)``, so give it an explicit
+        # (batch, 1) layout; this also keeps a batch of one from collapsing
+        # to a 0-dim tensor.
+        action = action.reshape(-1, 1)
+        action_log_probs = dist.log_probs(action).squeeze(-1)
+        action = action.squeeze(-1)
+        dist_entropy = dist.entropy()
+
+        if full_log_probs:
+            log_probs = torch.log(dist.probs)
+            return action, action_log_probs, dist_entropy, value, log_probs
+
+        return action, action_log_probs, dist_entropy, value
diff --git a/syllabus/examples/task_wrappers/minigrid_task_wrapper_verma.py b/syllabus/examples/task_wrappers/minigrid_task_wrapper_verma.py
new file mode 100644
index 00000000..cf440903
--- /dev/null
+++ b/syllabus/examples/task_wrappers/minigrid_task_wrapper_verma.py
@@ -0,0 +1,51 @@
+import gymnasium as gym
+import numpy as np
+from syllabus.core import TaskWrapper
+from syllabus.task_space import TaskSpace
+from gym_minigrid.wrappers import FullyObsWrapper, ImgObsWrapper
+from shimmy.openai_gym_compatibility import GymV21CompatibilityV0
+from gymnasium.spaces import Box
+
+class MinigridTaskWrapperVerma(TaskWrapper):
+    """Task wrapper that treats the environment seed as the task.
+
+    The task space is the 200 integers ``0..199``. Observations are transposed
+    from ``(H, W, C)`` to ``(C, H, W)`` and the advertised observation space
+    is rewritten to match.
+    """
+
+    def __init__(self, env: gym.Env, env_id, seed=0):
+        super().__init__(env)
+        self.env.unwrapped.seed(seed)
+        self.task_space = TaskSpace(gym.spaces.Discrete(200), list(np.arange(200)))
+        self.env_id = env_id
+        self.task = seed
+        self.episode_return = 0
+
+        # Rebuild the space channel-first, reusing the bounds found at index
+        # [0, 0, 0] of the wrapped space.
+        height, width, channels = self.env.observation_space.shape
+        wrapped_space = self.observation_space
+        self.observation_space = Box(
+            low=wrapped_space.low[0, 0, 0],
+            high=wrapped_space.high[0, 0, 0],
+            shape=[channels, height, width],
+            dtype=wrapped_space.dtype,
+        )
+
+    def observation(self, obs):
+        """Move the channel axis to the front."""
+        return obs.transpose(2, 0, 1)
+
+    def reset(self, new_task=None, **kwargs):
+        """Reset the episode, switching to ``new_task`` first when one is given."""
+        self.episode_return = 0.0
+        if new_task is not None:
+            self.change_task(new_task)
+        raw_obs, info = self.env.reset(**kwargs)
+        return self.observation(raw_obs), info
+
+    def change_task(self, new_task: int):
+        """
+        Switch task by reseeding the underlying environment.
+
+        ``new_task`` is converted to ``int``, stored as ``self.task`` and used
+        as the new seed. No validation against the task space happens here.
+        """
+        task_seed = int(new_task)
+        self.task = task_seed
+        self.seed(task_seed)
+
+    def seed(self, seed):
+        """Seed the unwrapped environment with ``int(seed)``."""
+        self.env.unwrapped.seed(int(seed))
+
+    def step(self, action):
+        """Step the environment, accumulate the episode return and transpose the observation."""
+        raw_obs, rew, term, trunc, info = self.env.step(action)
+        self.episode_return += rew
+        return self.observation(raw_obs), rew, term, trunc, info
diff --git a/syllabus/examples/training_scripts/.gitignore b/syllabus/examples/training_scripts/.gitignore
new file mode 100644
index 00000000..49e0461d
--- /dev/null
+++ b/syllabus/examples/training_scripts/.gitignore
@@ -0,0 +1,3 @@
+command.txt
+wandb
+requirements.txt
diff --git a/syllabus/examples/training_scripts/test_minigrid_wrapper.py b/syllabus/examples/training_scripts/test_minigrid_wrapper.py
new file mode 100644
index 00000000..80c19e78
--- /dev/null
+++ b/syllabus/examples/training_scripts/test_minigrid_wrapper.py
@@ -0,0 +1,553 @@
+import argparse
+import os, sys
+import random
+import time
+from collections import deque
+from distutils.util import strtobool
+
+import gym as openai_gym
+import gymnasium as gym
+import numpy as np
+import procgen # noqa: F401
+from procgen import ProcgenEnv
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from shimmy.openai_gym_compatibility import GymV21CompatibilityV0
+from torch.utils.tensorboard import SummaryWriter
+
+from syllabus.core import MultiProcessingSyncWrapper, make_multiprocessing_curriculum
+from syllabus.curricula import CentralizedPrioritizedLevelReplay, DomainRandomization, LearningProgressCurriculum, SequentialCurriculum
+from syllabus.examples.models import ProcgenAgent, MinigridAgent
+from syllabus.examples.task_wrappers import ProcgenTaskWrapper
+from syllabus.examples.utils.vecenv import VecMonitor, VecNormalize, VecExtractDictObs
+
+from gym_minigrid.wrappers import FullyObsWrapper, ImgObsWrapper
+sys.path.append("/data/averma/MARL/Syllabus/syllabus/examples/task_wrappers")
+sys.path.append("/data/averma/MARL/Syllabus/syllabus/examples/models")
+from minigrid_model_verma import *
+from minigrid_task_wrapper_verma import *
+import torch.nn as nn
+
+
+def parse_args():
+    """Parse command-line options for the PPO training script.
+
+    Returns:
+        ``argparse.Namespace`` with every option below plus two derived
+        fields: ``batch_size`` (``num_envs * num_steps``) and
+        ``minibatch_size`` (``batch_size // num_minibatches``).
+    """
+
+    def _str2bool(text):
+        # Local stand-in for ``distutils.util.strtobool`` (distutils is removed
+        # in Python 3.12); accepts the same spellings and raises ValueError
+        # for anything else so argparse reports a usage error.
+        token = text.lower()
+        if token in ("y", "yes", "t", "true", "on", "1"):
+            return True
+        if token in ("n", "no", "f", "false", "off", "0"):
+            return False
+        raise ValueError(f"invalid truth value {text!r}")
+
+    # fmt: off
+    parser = argparse.ArgumentParser()
+
+    def _add_flag(name, default, help_text):
+        # Boolean option usable as a bare flag (`--x`) or with a value (`--x false`).
+        parser.add_argument(name, type=_str2bool, default=default, nargs="?", const=True, help=help_text)
+
+    # splitext removes only the extension; rstrip(".py") would strip any
+    # trailing '.', 'p' or 'y' characters from the name itself.
+    parser.add_argument("--exp-name", type=str, default=os.path.splitext(os.path.basename(__file__))[0],
+        help="the name of this experiment")
+    parser.add_argument("--seed", type=int, default=1,
+        help="seed of the experiment")
+    _add_flag("--torch-deterministic", True,
+        "value assigned to `torch.backends.cudnn.deterministic` (pass false to disable)")
+    _add_flag("--cuda", True,
+        "if toggled, cuda will be enabled by default")
+    _add_flag("--track", False,
+        "if toggled, this experiment will be tracked with Weights and Biases")
+    parser.add_argument("--wandb-project-name", type=str, default="syllabus",
+        help="the wandb's project name")
+    parser.add_argument("--wandb-entity", type=str, default=None,
+        help="the entity (team) of wandb's project")
+    _add_flag("--capture-video", False,
+        "whether to capture videos of the agent performances (check out `videos` folder)")
+    parser.add_argument("--logging-dir", type=str, default=".",
+        help="the base directory for logging and wandb storage.")
+
+    # Algorithm specific arguments
+    parser.add_argument("--env-id", type=str, default="starpilot",
+        help="the id of the environment")
+    parser.add_argument("--total-timesteps", type=int, default=int(25e6),
+        help="total timesteps of the experiments")
+    parser.add_argument("--learning-rate", type=float, default=5e-4,
+        help="the learning rate of the optimizer")
+    parser.add_argument("--num-envs", type=int, default=64,
+        help="the number of parallel game environments")
+    parser.add_argument("--num-steps", type=int, default=256,
+        help="the number of steps to run in each environment per policy rollout")
+    _add_flag("--anneal-lr", False,
+        "Toggle learning rate annealing for policy and value networks")
+    _add_flag("--gae", True,
+        "Use GAE for advantage computation")
+    parser.add_argument("--gamma", type=float, default=0.999,
+        help="the discount factor gamma")
+    parser.add_argument("--gae-lambda", type=float, default=0.95,
+        help="the lambda for the general advantage estimation")
+    parser.add_argument("--num-minibatches", type=int, default=8,
+        help="the number of mini-batches")
+    parser.add_argument("--update-epochs", type=int, default=3,
+        help="the K epochs to update the policy")
+    _add_flag("--norm-adv", True,
+        "Toggles advantages normalization")
+    parser.add_argument("--clip-coef", type=float, default=0.2,
+        help="the surrogate clipping coefficient")
+    _add_flag("--clip-vloss", True,
+        "Toggles whether or not to use a clipped loss for the value function, as per the paper.")
+    parser.add_argument("--ent-coef", type=float, default=0.01,
+        help="coefficient of the entropy")
+    parser.add_argument("--vf-coef", type=float, default=0.5,
+        help="coefficient of the value function")
+    parser.add_argument("--max-grad-norm", type=float, default=0.5,
+        help="the maximum norm for the gradient clipping")
+    parser.add_argument("--target-kl", type=float, default=None,
+        help="the target KL divergence threshold")
+
+    # Procgen arguments
+    _add_flag("--full-dist", True,
+        "Train on full distribution of levels.")
+
+    # Curriculum arguments
+    _add_flag("--curriculum", False,
+        "if toggled, this experiment will use curriculum learning")
+    parser.add_argument("--curriculum-method", type=str, default="plr",
+        help="curriculum method to use")
+    parser.add_argument("--num-eval-episodes", type=int, default=10,
+        help="the number of episodes to evaluate the agent on after each policy update.")
+
+    args = parser.parse_args()
+    args.batch_size = int(args.num_envs * args.num_steps)
+    args.minibatch_size = int(args.batch_size // args.num_minibatches)
+    # fmt: on
+    return args
+
+
+# Pair of return bounds per Procgen game, keyed by game name; the only
+# reference to it in this script is in a commented-out line.
+PROCGEN_RETURN_BOUNDS = dict(
+    coinrun=(5, 10),
+    starpilot=(2.5, 64),
+    caveflyer=(3.5, 12),
+    dodgeball=(1.5, 19),
+    fruitbot=(-1.5, 32.4),
+    chaser=(0.5, 13),
+    miner=(1.5, 13),
+    jumper=(3, 10),
+    leaper=(3, 10),
+    maze=(5, 10),
+    bigfish=(1, 40),
+    heist=(3.5, 10),
+    climber=(2, 12.6),
+    plunder=(4.5, 30),
+    ninja=(3.5, 10),
+    bossfight=(0.5, 13),
+)
+
+
+def make_env_minigrid(env_name, seed, curriculum=None):
+    """Return a zero-argument factory that builds one wrapped MiniGrid env.
+
+    Args:
+        env_name: id passed to ``openai_gym.make``.
+        seed: initial seed (and task) handed to ``MinigridTaskWrapperVerma``.
+        curriculum: optional curriculum; when given, the env is additionally
+            wrapped in ``MultiProcessingSyncWrapper`` so it exchanges data
+            with the curriculum's components.
+
+    Returns:
+        A ``thunk`` suitable for ``gym.vector.AsyncVectorEnv``.
+    """
+    def thunk():
+        env = openai_gym.make(env_name)
+        # The observation wrappers and the task wrapper are applied on every
+        # path: the agent is built from ``single_observation_space.shape``, so
+        # it needs the image-only, channel-first observations with or without
+        # a curriculum.
+        env = FullyObsWrapper(env)
+        env = ImgObsWrapper(env)
+        env = GymV21CompatibilityV0(env=env)
+        env = MinigridTaskWrapperVerma(env=env, env_id=env_name, seed=seed)
+        if curriculum is not None:
+            env = MultiProcessingSyncWrapper(
+                env,
+                curriculum.get_components(),
+                update_on_step=False,
+                task_space=env.task_space,
+            )
+        return env
+
+    return thunk
+
+def wrap_vecenv(vecenv):
+    """Flag ``vecenv`` as a vector env, then add monitoring and return normalisation.
+
+    ``VecMonitor`` keeps a buffer of 100 entries and writes no file;
+    ``VecNormalize`` is configured with ``ob=False`` and ``ret=True``.
+    """
+    vecenv.is_vector_env = True
+    monitored = VecMonitor(venv=vecenv, filename=None, keep_buf=100)
+    return VecNormalize(venv=monitored, ob=False, ret=True)
+
+def level_replay_evaluate_minigrid(
+    env_name,
+    policy,
+    num_episodes,
+    device,
+    num_levels=0
+):
+    """Run ``policy`` until one episode has finished in each of ``num_episodes`` envs.
+
+    Reads the module-level ``args`` and ``curriculum`` to build the envs.
+
+    Args:
+        env_name: environment id forwarded to ``make_env_minigrid``.
+        policy: agent exposing ``get_action_and_value``; it is switched to
+            eval mode for the rollout and back to train mode afterwards.
+        num_episodes: number of parallel envs, one recorded episode each.
+        device: torch device the observations are moved to.
+        num_levels: accepted for call-site compatibility; currently unused.
+
+    Returns:
+        ``(mean_returns, stddev_returns, normalized_mean_returns)`` where the
+        normalisation uses fixed bounds of 0 and 1.
+    """
+    policy.eval()
+    eval_envs = gym.vector.AsyncVectorEnv(
+        [
+            make_env_minigrid(
+                env_name,
+                args.seed + i,
+                curriculum=curriculum if args.curriculum else None
+            )
+            for i in range(num_episodes)
+        ]
+    )
+    try:
+        eval_envs = wrap_vecenv(eval_envs)
+        eval_obs, _ = eval_envs.reset()
+        # -1 marks "no episode finished yet" for that env slot.
+        eval_episode_rewards = [-1] * num_episodes
+
+        while -1 in eval_episode_rewards:
+            with torch.no_grad():
+                eval_action, _, _, _ = policy.get_action_and_value(torch.Tensor(eval_obs).to(device))
+
+            eval_obs, _, _, _, infos = eval_envs.step(eval_action.cpu().numpy())
+            for i, info in enumerate(infos):
+                # Only the first finished episode per env is recorded.
+                if 'episode' in info.keys() and eval_episode_rewards[i] == -1:
+                    eval_episode_rewards[i] = info['episode']['r']
+
+        mean_returns = np.mean(eval_episode_rewards)
+        stddev_returns = np.std(eval_episode_rewards)
+        env_min = 0
+        env_max = 1
+        normalized_mean_returns = (mean_returns - env_min) / (env_max - env_min)
+        return mean_returns, stddev_returns, normalized_mean_returns
+    finally:
+        # This function runs on every policy update, so the worker processes
+        # must be shut down each time or they pile up.
+        eval_envs.close()
+        policy.train()
+
+
+def make_value_fn():
+    """Build a callable mapping observations to the global agent's value estimate.
+
+    The returned function reads the module-level ``agent`` and ``device`` when
+    it is called, not when it is created.
+    """
+    def get_value(obs):
+        as_array = np.array(obs)
+        with torch.no_grad():
+            batch = torch.Tensor(as_array).to(device)
+            return agent.get_value(batch)
+    return get_value
+
+def print_values(obj):
+    """Print every ``name: value`` pair from ``obj.__dict__``, then a blank line."""
+    for name, value in obj.__dict__.items():
+        print(f"{name}: {value}")
+    print()
+
+
+if __name__ == "__main__":
+    # PPO training on a MiniGrid environment, optionally driven by a Syllabus
+    # curriculum (prioritized level replay).
+
+    args = parse_args()
+    # The environment is fixed for this script; any --env-id value is overridden.
+    env_name = "MiniGrid-MultiRoom-N4-Random-v0"
+    args.env_id = env_name
+    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.backends.cudnn.deterministic = args.torch_deterministic
+    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
+
+    print("Device:", device)
+
+    if args.track:
+        import wandb
+
+        wandb.init(
+            project=args.wandb_project_name,
+            entity=args.wandb_entity,
+            sync_tensorboard=True,
+            config=vars(args),
+            name=run_name,
+            monitor_gym=True,
+            save_code=True,
+            dir=args.logging_dir
+        )
+
+    # Curriculum setup
+    curriculum = None
+
+    # The path must be an f-string; otherwise every run writes into a
+    # directory literally named "{run_name}".
+    writer = SummaryWriter(os.path.join(args.logging_dir, f"./runs/{run_name}"))
+    writer.add_text(
+        "hyperparameters",
+        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+    )
+
+    if args.curriculum:
+        print("args:\n--------------")
+        print(f"{args}\n-------------\n")
+
+        # Throwaway env, built only to read its task space.
+        sample_env = openai_gym.make(env_name)
+        sample_env = FullyObsWrapper(sample_env)
+        sample_env = ImgObsWrapper(sample_env)
+        sample_env = GymV21CompatibilityV0(env=sample_env)
+        sample_env = MinigridTaskWrapperVerma(sample_env, args.env_id, seed=args.seed)
+
+        print(f"has curriculum: {args.curriculum}")
+
+        if args.curriculum_method == "plr":
+            print("Using prioritized level replay.")
+            task_sampler_kwargs_dict = {"strategy": "value_l1", "temperature":0.1, "staleness_coef":0.3}
+            curriculum = CentralizedPrioritizedLevelReplay(
+                sample_env.task_space,
+                num_steps=args.num_steps,
+                num_processes=args.num_envs,
+                gamma=args.gamma,
+                gae_lambda=args.gae_lambda,
+                task_sampler_kwargs_dict=task_sampler_kwargs_dict
+            )
+        # elif args.curriculum_method == "dr":
+        #     print("Using domain randomization.")
+        #     curriculum = DomainRandomization(sample_env.task_space)
+        # elif args.curriculum_method == "lp":
+        #     print("Using learning progress.")
+        #     curriculum = LearningProgressCurriculum(sample_env.task_space)
+        # elif args.curriculum_method == "sq":
+        #     print("Using sequential curriculum.")
+        #     curricula = []
+        #     stopping = []
+        #     for i in range(199):
+        #         curricula.append(i + 1)
+        #         stopping.append("steps>=50000")
+        #         curricula.append(list(range(i + 1)))
+        #         stopping.append("steps>=50000")
+        #     curriculum = SequentialCurriculum(curricula, stopping[:-1], sample_env.task_space)
+        else:
+            raise ValueError(f"Unknown curriculum method {args.curriculum_method}")
+        curriculum = make_multiprocessing_curriculum(curriculum)
+        del sample_env
+
+    # env setup
+    print("Creating env")
+
+    envs = gym.vector.AsyncVectorEnv(
+        [
+            make_env_minigrid(
+                env_name,
+                args.seed + i,
+                curriculum=curriculum if args.curriculum else None
+            )
+            for i in range(args.num_envs)
+        ]
+    )
+    envs = wrap_vecenv(envs)
+    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
+
+    agent = MinigridAgentVerma(
+        envs.single_observation_space.shape,
+        envs.single_action_space.n,
+        arch="large",
+        base_kwargs={'recurrent': False, 'hidden_size': 256}
+    ).to(device)
+    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
+
+    # ALGO Logic: Storage setup
+    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
+    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
+    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
+    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
+    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
+    values = torch.zeros((args.num_steps, args.num_envs)).to(device)
+
+    # TRY NOT TO MODIFY: start the game
+    global_step = 0
+    start_time = time.time()
+    # Single reset; an extra reset right after env creation would only throw
+    # away its result.
+    next_obs, _ = envs.reset()
+    next_obs = torch.Tensor(next_obs).to(device)
+    next_done = torch.zeros(args.num_envs).to(device)
+    num_updates = args.total_timesteps // args.batch_size
+    episode_rewards = deque(maxlen=10)
+    completed_episodes = 0
+
+    for update in range(1, num_updates + 1):
+        # Annealing the rate if instructed to do so.
+        if args.anneal_lr:
+            frac = 1.0 - (update - 1.0) / num_updates
+            lrnow = frac * args.learning_rate
+            optimizer.param_groups[0]["lr"] = lrnow
+
+        for step in range(0, args.num_steps):
+            global_step += 1 * args.num_envs
+            obs[step] = next_obs
+            dones[step] = next_done
+
+            # ALGO LOGIC: action logic
+            with torch.no_grad():
+                action, logprob, _, value = agent.get_action_and_value(next_obs)
+                values[step] = value.flatten()
+            actions[step] = action
+            logprobs[step] = logprob
+
+            # TRY NOT TO MODIFY: execute the game and log data.
+            next_obs, reward, term, trunc, info = envs.step(action.cpu().numpy())
+            done = np.logical_or(term, trunc)
+            rewards[step] = torch.tensor(reward).to(device).view(-1)
+            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
+            completed_episodes += int(np.sum(done))
+
+            # Log at most one finished episode per step.
+            for item in info:
+                if "episode" in item.keys():
+                    episode_rewards.append(item['episode']['r'])
+                    print(f"global_step={global_step}, episodic_return={item['episode']['r']}")
+                    writer.add_scalar("charts/episodic_return", item["episode"]["r"], global_step)
+                    writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step)
+                    if curriculum is not None:
+                        curriculum.log_metrics(writer, global_step)
+                    break
+
+            # Syllabus curriculum update
+            if args.curriculum and args.curriculum_method == "plr":
+                with torch.no_grad():
+                    next_value = agent.get_value(next_obs)
+                tasks = envs.get_attr("task")
+
+                # Named ``plr_update`` so it does not shadow the outer loop
+                # variable ``update``.
+                plr_update = {
+                    "update_type": "on_demand",
+                    "metrics": {
+                        "value": value,
+                        "next_value": next_value,
+                        "rew": reward,
+                        "dones": done,
+                        "tasks": tasks,
+                    },
+                }
+                curriculum.update(plr_update)
+
+        # bootstrap value if not done
+        with torch.no_grad():
+            next_value = agent.get_value(next_obs).reshape(1, -1)
+            if args.gae:
+                advantages = torch.zeros_like(rewards).to(device)
+                lastgaelam = 0
+                for t in reversed(range(args.num_steps)):
+                    if t == args.num_steps - 1:
+                        nextnonterminal = 1.0 - next_done
+                        nextvalues = next_value
+                    else:
+                        nextnonterminal = 1.0 - dones[t + 1]
+                        nextvalues = values[t + 1]
+                    delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
+                    advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
+                returns = advantages + values
+            else:
+                returns = torch.zeros_like(rewards).to(device)
+                for t in reversed(range(args.num_steps)):
+                    if t == args.num_steps - 1:
+                        nextnonterminal = 1.0 - next_done
+                        next_return = next_value
+                    else:
+                        nextnonterminal = 1.0 - dones[t + 1]
+                        next_return = returns[t + 1]
+                    returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
+                advantages = returns - values
+
+        # flatten the batch
+        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
+        b_logprobs = logprobs.reshape(-1)
+        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
+        b_advantages = advantages.reshape(-1)
+        b_returns = returns.reshape(-1)
+        b_values = values.reshape(-1)
+
+        # Optimizing the policy and value network
+        b_inds = np.arange(args.batch_size)
+        clipfracs = []
+        for epoch in range(args.update_epochs):
+            np.random.shuffle(b_inds)
+            for start in range(0, args.batch_size, args.minibatch_size):
+                end = start + args.minibatch_size
+                mb_inds = b_inds[start:end]
+
+                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
+                logratio = newlogprob - b_logprobs[mb_inds]
+                ratio = logratio.exp()
+
+                with torch.no_grad():
+                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
+                    old_approx_kl = (-logratio).mean()
+                    approx_kl = ((ratio - 1) - logratio).mean()
+                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]
+
+                mb_advantages = b_advantages[mb_inds]
+                if args.norm_adv:
+                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
+
+                # Policy loss
+                pg_loss1 = -mb_advantages * ratio
+                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
+                pg_loss = torch.max(pg_loss1, pg_loss2).mean()
+
+                # Value loss
+                newvalue = newvalue.view(-1)
+                if args.clip_vloss:
+                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
+                    v_clipped = b_values[mb_inds] + torch.clamp(
+                        newvalue - b_values[mb_inds],
+                        -args.clip_coef,
+                        args.clip_coef,
+                    )
+                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
+                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
+                    v_loss = 0.5 * v_loss_max.mean()
+                else:
+                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()
+
+                entropy_loss = entropy.mean()
+                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
+
+                optimizer.zero_grad()
+                loss.backward()
+                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
+                optimizer.step()
+
+            if args.target_kl is not None:
+                if approx_kl > args.target_kl:
+                    break
+
+        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
+        var_y = np.var(y_true)
+        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
+
+        # Evaluate agent
+        # NOTE(review): the evaluation helper does not use ``num_levels``, so
+        # both calls below sample from the same environments.
+        mean_eval_returns, stddev_eval_returns, normalized_mean_eval_returns = level_replay_evaluate_minigrid(
+            args.env_id, agent, args.num_eval_episodes, device, num_levels=0
+        )
+        mean_train_returns, stddev_train_returns, normalized_mean_train_returns = level_replay_evaluate_minigrid(
+            args.env_id, agent, args.num_eval_episodes, device, num_levels=200
+        )
+
+        # TRY NOT TO MODIFY: record rewards for plotting purposes
+        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
+        # Skip the running mean until an episode has finished; the mean of an
+        # empty deque is NaN and triggers a runtime warning.
+        if episode_rewards:
+            writer.add_scalar("charts/episode_returns", np.mean(episode_rewards), global_step)
+        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
+        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
+        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
+        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
+        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
+        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
+        writer.add_scalar("losses/explained_variance", explained_var, global_step)
+        print("SPS:", int(global_step / (time.time() - start_time)))
+        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+
+        writer.add_scalar("test_eval/mean_episode_return", mean_eval_returns, global_step)
+        writer.add_scalar("test_eval/normalized_mean_eval_return", normalized_mean_eval_returns, global_step)
+        writer.add_scalar("test_eval/stddev_eval_return", stddev_eval_returns, global_step)
+
+        writer.add_scalar("train_eval/mean_episode_return", mean_train_returns, global_step)
+        writer.add_scalar("train_eval/normalized_mean_train_return", normalized_mean_train_returns, global_step)
+        writer.add_scalar("train_eval/stddev_train_return", stddev_train_returns, global_step)
+
+        # Logged against global_step like every other scalar; ``step`` is the
+        # inner rollout index and ends each update at the same value.
+        writer.add_scalar("curriculum/completed_episodes", completed_episodes, global_step)
+
+    envs.close()
+    writer.close()
diff --git a/tests/cleanrl_cartpole_test.sh b/tests/cleanrl_cartpole_test.sh
old mode 100755
new mode 100644