diff --git a/L03-perceptron/code/perceptron-numpy.ipynb b/L03-perceptron/code/perceptron-numpy.ipynb index f3d925e..4f3ed4e 100755 --- a/L03-perceptron/code/perceptron-numpy.ipynb +++ b/L03-perceptron/code/perceptron-numpy.ipynb @@ -23,15 +23,22 @@ "Sebastian Raschka \n", "\n", "CPython 3.7.1\n", - "IPython 7.10.1\n", + "IPython 7.11.1\n", "\n", - "numpy 1.17.4\n" + "torch 1.4.0\n" ] } ], "source": [ "%load_ext watermark\n", - "%watermark -a 'Sebastian Raschka' -v -p numpy" + "%watermark -a 'Sebastian Raschka' -v -p torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Runs on CPU or GPU (if available)" ] }, { @@ -45,7 +52,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Implementation of the classic Perceptron by Frank Rosenblatt for binary classification (here: 0/1 class labels) in NumPy" + "Implementation of the classic Perceptron by Frank Rosenblatt for binary classification (here: 0/1 class labels) in PyTorch" ] }, { @@ -63,6 +70,7 @@ "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import torch\n", "%matplotlib inline" ] }, @@ -189,15 +197,24 @@ "metadata": {}, "outputs": [], "source": [ + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "\n", "class Perceptron():\n", " def __init__(self, num_features):\n", " self.num_features = num_features\n", - " self.weights = np.zeros((num_features, 1), dtype=np.float)\n", - " self.bias = np.zeros(1, dtype=np.float)\n", + " self.weights = torch.zeros(num_features, 1, \n", + " dtype=torch.float32, device=device)\n", + " self.bias = torch.zeros(1, dtype=torch.float32, device=device)\n", + " \n", + " # placeholder vectors so they don't\n", + " # need to be recreated each time\n", + " self.ones = torch.ones(1)\n", + " self.zeros = torch.zeros(1)\n", "\n", " def forward(self, x):\n", - " linear = np.dot(x, self.weights) + self.bias\n", - " predictions = np.where(linear > 0., 1, 0)\n", + " linear = torch.add(torch.mm(x, self.weights), self.bias)\n", + " predictions = torch.where(linear > 0., self.ones, self.zeros)\n", " return predictions\n", " \n", " def backward(self, x, y): \n", @@ -209,13 +226,14 @@ " for e in range(epochs):\n", " \n", " for i in range(y.shape[0]):\n", + " # use view because backward expects a matrix (i.e., 2D tensor)\n", " errors = self.backward(x[i].reshape(1, self.num_features), y[i]).reshape(-1)\n", " self.weights += (errors * x[i]).reshape(self.num_features, 1)\n", " self.bias += errors\n", " \n", " def evaluate(self, x, y):\n", " predictions = self.forward(x).reshape(-1)\n", - " accuracy = np.sum(predictions == y) / y.shape[0]\n", + " accuracy = torch.sum(predictions == y).float() / y.shape[0]\n", " return accuracy" ] }, @@ -236,24 +254,23 @@ "output_type": "stream", "text": [ "Model parameters:\n", - "\n", - "\n", - " Weights: [[1.27340847]\n", - " [1.34642288]]\n", - "\n", - " Bias: [-1.]\n", - "\n" + " Weights: tensor([[1.2734],\n", + " [1.3464]])\n", + " Bias: tensor([-1.])\n" ] } ], "source": [ "ppn = Perceptron(num_features=2)\n", "\n", - "ppn.train(X_train, y_train, epochs=5)\n", + "X_train_tensor = torch.tensor(X_train, dtype=torch.float32, device=device)\n", + "y_train_tensor = torch.tensor(y_train, dtype=torch.float32, device=device)\n", + "\n", + "ppn.train(X_train_tensor, y_train_tensor, epochs=5)\n", "\n", - "print('Model parameters:\\n\\n')\n", - "print(' Weights: %s\\n' % ppn.weights)\n", - "print(' Bias: %s\\n' % ppn.bias)" + "print('Model parameters:')\n", + "print(' Weights: %s' % ppn.weights)\n", + "print(' Bias: %s' % ppn.bias)" ] }, { @@ -277,7 +294,10 @@ } ], "source": [ - "test_acc = ppn.evaluate(X_test, y_test)\n", + "X_test_tensor = torch.tensor(X_test, dtype=torch.float32, device=device)\n", + "y_test_tensor = torch.tensor(y_test, dtype=torch.float32, device=device)\n", + "\n", + "test_acc = ppn.evaluate(X_test_tensor, y_test_tensor)\n", "print('Test set accuracy: %.2f%%' % (test_acc*100))" ] }, diff --git a/L03-perceptron/code/perceptron-pytorch.ipynb b/L03-perceptron/code/perceptron-pytorch.ipynb deleted file mode 100755 index 4f3ed4e..0000000 --- a/L03-perceptron/code/perceptron-pytorch.ipynb +++ /dev/null @@ -1,387 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "STAT 453: Deep Learning (Spring 2020) \n", - "\n", - "Instructor: Sebastian Raschka (sraschka@wisc.edu) \n", - "Course website: http://pages.stat.wisc.edu/~sraschka/teaching/stat453-ss2020/ \n", - "GitHub repository: https://github.com/rasbt/stat453-deep-learning-ss20" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Sebastian Raschka \n", - "\n", - "CPython 3.7.1\n", - "IPython 7.11.1\n", - "\n", - "torch 1.4.0\n" - ] - } - ], - "source": [ - "%load_ext watermark\n", - "%watermark -a 'Sebastian Raschka' -v -p torch" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "- Runs on CPU or GPU (if available)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# L03: Perceptrons" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Implementation of the classic Perceptron by Frank Rosenblatt for binary classification (here: 0/1 class labels) in PyTorch" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Imports" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "import torch\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preparing a toy dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Class label counts: [50 50]\n", - "X.shape: (100, 2)\n", - "y.shape: (100,)\n" - ] - } - ], - "source": [ - "##########################\n", - "### DATASET\n", - "##########################\n", - "\n", - "data = np.genfromtxt('perceptron_toydata.txt', delimiter='\\t')\n", - "X, y = data[:, :2], data[:, 2]\n", - "y = y.astype(np.int)\n", - "\n", - "print('Class label counts:', np.bincount(y))\n", - "print('X.shape:', X.shape)\n", - "print('y.shape:', y.shape)\n", - "\n", - "# Shuffling & train/test split\n", - "shuffle_idx = np.arange(y.shape[0])\n", - "shuffle_rng = np.random.RandomState(123)\n", - "shuffle_rng.shuffle(shuffle_idx)\n", - "X, y = X[shuffle_idx], y[shuffle_idx]\n", - "\n", - "X_train, X_test = X[shuffle_idx[:70]], X[shuffle_idx[70:]]\n", - "y_train, y_test = y[shuffle_idx[:70]], y[shuffle_idx[70:]]\n", - "\n", - "# Normalize (mean zero, unit variance)\n", - "mu, sigma = X_train.mean(axis=0), X_train.std(axis=0)\n", - "X_train = (X_train - mu) / sigma\n", - "X_test = (X_test - mu) / sigma" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAEWCAYAAABmE+CbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAe80lEQVR4nO3de5hddX3v8feHODZDSImS9BRmggRCA5LrMQIpFeuFBj1yEfBBoWC85eh5BLw0CsZKAEFo8FboMYaiNRQVEExTrUZ5cipKRUhIgESIKIrMgIqxI2CCmSTf88deSSbJ7MvMrL3XZX9ezzPPM3vttdf6rgz8vnv9ft/fbykiMDMz2y/rAMzMLB+cEMzMDHBCMDOzhBOCmZkBTghmZpZwQjAzM8AJwQwASaMkPSfp0DT3NSsSJwQrpKRB3vmzQ9KWAa/PHerxImJ7RBwQEb9Mc99WkfQDSfOyjsOK7QVZB2A2HBFxwM7fJf0CeGdE3Fltf0kviIhtrYjNrKh8h2ClJOnjkm6R9BVJzwJ/K2mOpHsk9Ul6StI/SupI9n+BpJB0WPL6X5P3vyXpWUk/lDRpqPsm779O0k8k/V7SdZLurvZtXtLxku6X9IykX0taPOC9EwbEv07Sicn2a4A5wJLkDukz6f+LWjtwQrAyeyPwZeBA4BZgG3ARMB44ATgZ+N81Pn8O8PfAi4FfAlcMdV9JfwbcCixIzvtz4Ngax7kOWBwRfwpMBr6WHGcisAK4NDnHxcAdkg6KiA8DPwTenXRlva/G8c2qckKwMvtBRPx7ROyIiC0RcV9E/CgitkXEY8BS4JU1Pv+1iFgdEf3AzcDMYez7BmBdRPxb8t6ngd/WOE4/cGTS0D8bET9Ktp8PrIiIlcn1fBt4gEpSM0uFE4KV2RMDX0g6StI3Jf1K0jPA5VS+tVfzqwG/bwYOqLZjjX0PGRhHVFaT7KlxnLcBLwU2SrpX0uuT7S8B3pJ0F/VJ6gOOT45vlgonBCuzvZfy/TywHpicdMl8DFCTY3gK6N75QpKArmo7R8TGiHgz8GfAJ4HbJY2mklS+GBHjBvyMiYidYwxetthGzAnB2slY4PfAHyQdTe3xg7R8A/ifkk6R9AIqYxgTqu0s6TxJ4yNiRxJrADuAm4A3SjopmQcxWtKrJO28Q/g1cHhzL8XKzgnB2skHgbcCz1K5W7il2SeMiF8DZwOfAjYBRwBrgT9W+cjrgYeTyqhrgbMjYmtE/ILKIPnfA09TGbj+ILv/H/4Mu7uUPtWky7GSkx+QY9Y6kkYBTwJnRcT3s47HbCDfIZg1maSTJR0o6U+ofMPfBtybcVhm+8gsISR9oPdKekDSBkmXZRWLWZP9FfAYlXLTk4HTI6Jal5FZZjLrMkqqLcZExHPJbNEfABdFxD2ZBGRm1uYyW8soqcd+LnnZkfx4QMPMLCOZLm6XDLCtoTJF/58GzMocuM98YD7AmDFjXnbUUUe1Nkgzs4Jbs2bNbyOiarnzTrmoMpI0Dvg6cEFErK+23+zZs2P16tWtC8zMrAQkrYmI2fX2y0WVUUT0Af+J12UxM8tMllVGE5I7AyR1Aq8FHskqHjOzdpflGMLBwJeScYT9gFsj4hsZxmNm1tayrDJ6EJiV1fnNrLj6+/vp6enh+eefzzqUXBk9ejTd3d10dHQM6/N+hKaZFU5PTw9jx47lsMMOozKlySKCTZs20dPTw6RJk+p/YBC5GFQ2MxuK559/noMOOsjJYABJHHTQQSO6a3JCMLNCcjLY10j/TZwQzMwMcEIwM0vNokWLuPbaa5ty7DVr1jBt2jQmT57MhRdeSDMmFTshmJkVwHve8x6WLl3Ko48+yqOPPsq3v/3t1M/hhGBmpbd8bS8nXL2KSRd/kxOuXsXytb0jPuayZcuYPn06M2bM4Lzzztvn/RtuuIGXv/zlzJgxgzPPPJPNmzcDcNtttzF16lRmzJjBiSeeCMCGDRs49thjmTlzJtOnT+fRRx/d41hPPfUUzzzzDHPmzEES559/PsuXLx/xNezNZadmVmrL1/ZyyR0PsaV/OwC9fVu45I6HADh9VtewjrlhwwauvPJK7r77bsaPH8/vfve7ffY544wzeNe73gXARz/6UW688UYuuOACLr/8clauXElXVxd9fX0ALFmyhIsuuohzzz2XrVu3sn379j2O1dvbS3d3967X3d3d9PaOPKntzXcIZlZqi1du3JUMdtrSv53FKzcO+5irVq3irLPOYvz48QC8+MUv3mef9evX84pXvIJp06Zx8803s2HDBgBOOOEE5s2bxw033LCr4Z8zZw5XXXUV11xzDY8//jidnZ17HGuw8YJmVFk5IZhZqT3Zt2VI2xsREXUb5Hnz5nH99dfz0EMPcemll+6aH7BkyRI+/vGP88QTTzBz5kw2bdrEOeecw4oVK+js7GTu3LmsWrVqj2N1d3fT09Oz63VPTw+HHHLIsOOvxgnBzErtkHGdQ9reiNe85jXceuutbNq0CWDQLqNnn32Wgw8+mP7+fm6++eZd23/2s59x3HHHcfnllzN+/HieeOIJHnvsMQ4//HAuvPBCTj31VB588ME9jnXwwQczduxY7rnnHiKCZcuWcdpppw07/mqcEMys1BbMnUJnx6g9tnV2jGLB3CnDPuYxxxzDwoULeeUrX8mMGTP4wAc+sM8+V1xxBccddxwnnXQSAx/stWDBAqZNm8bUqVM58cQTmTFjBrfccgtTp05l5syZPPLII5x//vn7HO9zn/sc73znO5k8eTJHHHEEr3vd64YdfzW5eEBOo/yAHDMDePjhhzn66KMb3n/52l4Wr9zIk31bOGRcJwvmThn2gHLeDfZv0+gDclxlZGald/qsrtImgDS5y8jMzAAnBDMzSzghmJkZ4IRgZmYJJwQzMwOcEMzMUtPM5a8XLlzIxIkTOeCAA5pyfHBCMDMrhFNOOYV77723qefwPAQzK7erumDrc/tuf+EB8JHhrxi6bNkyrr32WiQxffp0brrppj3ev+GGG1i6dClbt25l8uTJ3HTTTey///7cdtttXHbZZYwaNYoDDzyQu+66iw0bNvC2t72NrVu3smPHDm6//XaOPPLIPY53/PHHDzvWRjkhmFm5DZYMam1vQKuXv24VdxmZmQ1Rq5e/bhUnBDOzIWr18tet4oRgZjZErV7+ulWcEMzMhiiL5a8/9KEP0d3dzebNm+nu7mbRokWpX1dmy19LmggsA/4c2AEsjYjP1vqMl782Mxji8tdNqjLKq6Iuf70N+GBE3C9pLLBG0ncj4scZxmRmZVPCRr9ZMusyioinIuL+5PdngYcBL1huZpaRXIwhSDoMmAX8KNtIzKwoivS0x1YZ6b9J5glB0gHA7cD7IuKZQd6fL2m1pNVPP/106wM0s9wZPXo0mzZtclIYICLYtGkTo0ePHvYxMn2msqQO4BvAyoj4VL39PahsZgD9/f309PTsqu23itGjR9Pd3U1HR8ce23M/qKzKrI4bgYcbSQZmZjt1dHQwadKkrMMonSy7jE4AzgNeLWld8vP6DOMxM2trmd0hRMQPgNpzv83MrGW82qlZETUy2arNJmTZyGVeZWRmw9DIks619ll0YCVhmA3ghGDWrkbwPAArJycEMzMDPIZgZuDxBgN8h2Bm0JTHTFrxOCGYFdELD6i/vdo+ZlW4y8isiBrpxmmk/NRsACcEs7LzGIA1yF1GZmYG+A7BzKDSfVSvW8mVSKXnhGBmjTXorkQqPXcZmZkZ4IRgZmYJJwQzMwOcEMzMLOGEYGaNaWR2tBWaq4zMrDH1KpGKWJZaxJibyAnBzNKRdVnqcBr3rGPOGScEs3ZT1m/FbtxHzGMIZu3GDadV4YRgZmaAu4zMiqHRbp6ydgdZS/gOwawIGu3mybI7qIhlqUWMuYl8h2BWdIsOzDqCiqzvQBpZsXVvWcecM04IZu1mOA1nEbhxHzEnBLMsZNnX74bTqvAYglkWXPppOZTpHYKkLwBvAH4TEVOzjMUs16p18wzl81lzBVTuZd1l9C/A9cCyjOMwy7dqDWatAeVFv29OLMPlu6Lcy7TLKCLuAn6XZQxmheaySUtR1ncIdUmaD8wHOPTQQzOOxixn3NViKcp9QoiIpcBSgNmzZ0fG4Zilo9HSz3bpd7+qq1zXU1C5TwhmpdRo4zfSfveiJBSPI+SCy07NyixPA7ke18i9TBOCpK8APwSmSOqR9I4s4zGzJsrTHYkNKtMuo4h4S5bnNyuNal1DWZ47b91SVpe7jMzKIMs++FrdUld1tTYWGxEnBLM8K/o8g72TRdGvp+RcZWSWZ83qcsmqAXYXUq45IZiVXd6WsLDcckIwywMPzFoOeAzBLA9GOl8gy7559/+Xhu8QzBqwfG0vi1du5Mm+LRwyrpMFc6dw+qwcVdBkeRfxkd7adzhWGE4IZnUsX9vLJXc8xJb+7QD09m3hkjseAshXUshSvYTkLrFCcEIwq2Pxyo27ksFOW/q3s3jlxuwTQhoNbSsa6zwtoWFVOSGY1fFk35YhbW+pRhraeg2+G2tLVB1UljRR0lclfV/SRyR1DHhveWvCM8veIeM6h7R9WJoxKLxzlrAbfGtQrTuELwC3A/cA7wC+J+mUiNgEvKQVwZnlwYK5U/YYQwDo7BjFgrlT0jtJM/rR3eDbENVKCBMiYkny+wWS/ha4S9KpgB9UY21j5zhBrquMzFJQKyF0SBodEc8DRMS/SvoVsBIY05LobA+5L30ssdNndfnfeiQafUKcZapWQvhn4Djgezs3RMSdkt4E/EOzA7M9ufSxHFJP6tUa2jSOkWZj7dLSQqiaECLi01W2rwVOalpENqhclz5aQ5qS1D/SC4sOrL1PvQZ/qI11rWcvDLVU1fMTcsVlpwWR69JHa0jTknpaDX6jjXOtO5Kh3q24AipXnBAK4pBxnfQO0vinWvqYQ2UaN2laUvfkMUuJF7criAVzp9DZMWqPbamXPubMzi6W3r4tBLu7WJavLWZXQkvmM5iNQN2EIOl/SLpR0reS1y+V9I7mh2YDnT6ri0+cMY2ucZ0I6BrXySfOmFbYb8uNqNXFUkTtmNStWBrpMvoX4IvAwuT1T4BbgBubFJNV0W6lj2UbN8l8PoMHcK2ORhLC+Ii4VdIlABGxTdL2eh8yG6kyjptkmtTTGiOoVeo61FJVz0/IlUYSwh8kHUQyO1nS8YCfyWdN15IlI2y3RhvnNO8mfGeSK40khA8AK4AjJN0NTADOampUZuSgi6XduHFuezUTgqT9gNHAK4EpgICNEdHfgtjM2m7cpPA8TlFoNRNCROyQ9MmImANsaFFMZlWVaV5CKXkuQ6E1Mg/hO5LOlKSmR2NWQ9nmJbRcM565YKXS6BjCGGCbpOepdBtFRPzpSE8u6WTgs8Ao4J8j4uqRHtPKy+s5jZC7bKyOugkhIsY248SSRgH/RGWhvB7gPkkrIuLHzTifFV/Z5iWY5U3dhCDpxMG2R8RdIzz3scBPI+Kx5DxfBU4DnBBsUGWcl2CWJ410GS0Y8PtoKg35GuDVIzx3F/DEgNc9VJ6/sAdJ84H5AIceeugIT2lZSGsg2PMSCsATzQqtkS6jUwa+ljSRdB6QM9gg9T6P5oyIpcBSgNmzZ/vRnQWT5jMA8j4vwRVQeJyi4Iaz/HUPMDWFc/cAEwe87gaeTOG4liNpDwTndV6Cn2hnZdDIGMJ17P7mvh8wE3gghXPfBxwpaRLQC7wZOCeF41qOtMtAsCugrAwauUNYPeD3bcBXIuLukZ44WSTvvcBKKmWnX4gIT34rmXYZCG6XxGfl1sjEtHER8aXk5+aIuFvSRWmcPCL+IyL+IiKOiIgr0zim5Uu7PAPAD7+xMmgkIbx1kG3zUo7DSqqZD/ZZvraXE65exaSLv8kJV6/KdMZyuyQ+K7eqXUaS3kKlT3+SpBUD3hoLbGp2YFYezRgIztsgbt4roMwaUWsM4b+Ap4DxwCcHbH8WeLCZQVnzFb1EMo+DuHmtgDJrVNWEEBGPA48Dc1oXjrVC3r5dD4cHcc3SV3cMQdLxku6T9JykrZK2S3qmFcFZc5Th4fUexDVLXyODytcDbwEeBTqBdwLXNTMoa46dg7CDlYFCsb5dexDXLH0NzVSOiJ9KGhUR24EvSvqvJsdlKdu7m2gwRfp27UFcs/Q1khA2S3ohsE7SP1AZaB7T3LAsbYN1Ew1UxG/XHsQ1S1cjCeE8Kl1L7wXeT2X9oTObGVTeFbFCp1Z3UFdBrsHMmquR1U4fl9QJHBwRl7UgplwraoVOtSUkusZ1cvfFI13J3MzKoJEqo1OAdcC3k9cz95qo1laKWqHjQVgzq6eRKqNFVB6K0wcQEeuAw5oXUr4Vtf69mUtImFk5NDKGsC0ifi8N9jyb9lPk1Ts9CGtmtTRyh7Be0jnAKElHJs9HaNuyU3e9tEaeFq4zaxeN3CFcACwE/gh8mcrzCz7ezKDyrN3q31tdUbV8bS+LVmygb0v/rm1FGbhPUxEr2az4FDH4Y4ol3RQR50m6KCI+2+K4BjV79uxYvXp1/R0tFYNNZuvsGNW0sYd6k+d2VkSVvbFs9b+7lZ+kNRExu95+tbqMXibpJcDbJb1I0osH/qQXquVVqyuq6k2ee7Jvy67GsrdvC8Huu4fBupSK2u1U1Eo2K75aXUZLqJSaHg6sAQaOKkey3Uqs1RVV9Y67n8T7blm3z/bBlr0u6nwRKG4lmxVf1TuEiPjHiDiayrOOD4+ISQN+nAzaQKtXFK133O1Vujdh38ayyN+yvZKrZaVulVFEvKcVgVj+tLqiarDzAezXQMXz3o1lkb9lu5LNstJI2am1qbQns9Xr0x/sfJ85eyY1bgyAwRvLIn/L9iRCy0rVKqM8cpVRcY2kcqbWMxyqLcw3lPPtXer6ov07uPSUY9wAW2mkUWVklppqffqLVmyo+9lqXSifOXsmd1/86kEb7ka/ZS9f28uC2x7YY97Df2/uZ8HXHihMVZJZWhp6QI61pzTr/av13fdt6Wf52t6axx3uZMBGlupYvHIj/Tv2vUvu3x77VC6ZlZ0Tgg0q7bLNamtAAQ01vM1ah6nWIHMRBqDN0uQuIxtU2mWbtSpksmx4aw0yF2EA2ixNTgg2qLTLNk+f1cWL9u8Y9L0sG94Fc6fQMUhda8couczT2k4mCUHSmyRtkLRDUt2Rb2u9ZpRtXnrKMbmrrz99VheL3zSDcZ27k9WL9u9g8VkzdnVRFXUJDLOhymoMYT1wBvD5jM5vdSyYO2XQss2RNN7DHRxu9mJ2tcYnGhlLKftie9Y+MkkIEfEwgB+6M3ytaCQh/WW+hzo4nPWaRLXGUk6f1TXk+Jw8LM9yX2UkaT4wH+DQQw/NOJp8aFUjmYcnrNVrkJut3ljKUOLLOrmZ1dO0MQRJd0paP8jPaUM5TkQsjYjZETF7woQJzQq3UIq4cNtw++GzXpOo3ljKUOIr4t/N2kvT7hAi4rXNOna7y7qRHGj52l4u+/cN/PfmykzfcZ0dLDp1z2UfRvLNOOtnWNcbSxlKfHn6u5kNxmWnBZSXhduWr+1lwdce2JUMoDLzeMFtey77MJJvxtWWrXjVURNaUvlTbwmMoaxMmpe/m1k1mSxuJ+mNwHXABKAPWBcRc+t9zovbVeTlEYu1Fp0bJbEjouYMZQE/v/p/1T3P3gOxrzpqArev6c38+qvFV22gOC9/N2s/jS5u59VOCyoP1SqTLv4mjfzXIxh0v53PSB6qaolouMdrpTz83az9NJoQcl9lZIPLQwVQrW//AwX7JoWRzGkocl98Hv5uZtV4DMGGbcHcKXSMamwuSUBqD3xxX7xZc/gOwYZtZ4M+sMpIYtAnnKXZndOMWdRm5oRgI7R3F0i1gdM0G+tmzaI2a3dOCJaqVjXW9dYfcrIwGzonBEtdlgOnXh7CbPg8qGyl4uUhzIbPCcFKpcglqWZZc0KwUnFJqtnwOSFYqQxlbSEz25MHla1UXJJqNnxOCFY6Xh7CbHjcZWRmZoATgpmZJZwQzMwMcEIwM7OEB5UtdV5LyKyYnBAsVV5LyKy43GVkqfJaQmbF5YRgqfJaQmbF5YRgqfJaQmbF5YRgqfJaQmbF5UFlS5XXEjIrLicES53XEjIrJncZmZkZ4DsEKzlPkjNrnBOClZYnyZkNTSZdRpIWS3pE0oOSvi5pXBZxWLl5kpzZ0GQ1hvBdYGpETAd+AlySURxWYp4kZzY0mXQZRcR3Bry8BzgrizhsaIrWH3/IuE56B2n8PUnObHB5qDJ6O/CtrIOw2nb2x/f2bSHY3R+/fG1v1qFV5UlyZkPTtIQg6U5J6wf5OW3APguBbcDNNY4zX9JqSauffvrpZoVrdRSxP/70WV184oxpdI3rREDXuE4+cca0XN/VmGWpaV1GEfHaWu9LeivwBuA1ERE1jrMUWAowe/bsqvtZcxW1P96T5Mwal1WV0cnAh4FTI2JzFjHY0HjROrPyy2oM4XpgLPBdSeskLckoDmuQ++PNyi+rKqPJWZzXhm5gZdGBnR2M7tiPvs39hagyMrOh8Uxlq2rvmb59W/rp7BjFp8+e6URgVkJ5KDu1nCpiZZGZDZ8TglVV1MoiMxseJwSrypVFZu3FCcGqcmWRWXvxoLJV5cdhmrUXJwSryTN9zdqHu4zMzAxwQjAzs4QTgpmZAU4IZmaWcEIwMzPACcHMzBJOCGZmBjghmJlZwgnBzMwAJwQzM0s4IZiZGeCEYGZmCScEMzMDnBDMzCzhhGBmZoATgpmZJZwQzMwMcEIwM7OEE4KZmQFOCGZmlnBCMDMzIKOEIOkKSQ9KWifpO5IOySIOMzPbLas7hMURMT0iZgLfAD6WURxmZpbIJCFExDMDXo4BIos4zMxstxdkdWJJVwLnA78HXlVjv/nA/OTlHyWtb0F4WRkP/DbrIJqozNdX5msDX1/RTWlkJ0U058u5pDuBPx/krYUR8W8D9rsEGB0RlzZwzNURMTvFMHPF11dcZb428PUVXaPX17Q7hIh4bYO7fhn4JlA3IZiZWfNkVWV05ICXpwKPZBGHmZntltUYwtWSpgA7gMeBdzf4uaXNCykXfH3FVeZrA19f0TV0fU0bQzAzs2LxTGUzMwOcEMzMLFG4hFDmZS8kLZb0SHJ9X5c0LuuY0iTpTZI2SNohqTQlfpJOlrRR0k8lXZx1PGmS9AVJvynr/B9JEyX9P0kPJ/9tXpR1TGmRNFrSvZIeSK7tsrqfKdoYgqQ/3TnTWdKFwEsjotFB6VyT9DfAqojYJukagIj4cMZhpUbS0VQKCT4P/F1ErM44pBGTNAr4CXAS0APcB7wlIn6caWApkXQi8BywLCKmZh1P2iQdDBwcEfdLGgusAU4vw99PkoAxEfGcpA7gB8BFEXFPtc8U7g6hzMteRMR3ImJb8vIeoDvLeNIWEQ9HxMas40jZscBPI+KxiNgKfBU4LeOYUhMRdwG/yzqOZomIpyLi/uT3Z4GHga5so0pHVDyXvOxIfmq2l4VLCFBZ9kLSE8C5lHdhvLcD38o6CKurC3hiwOseStKgtBtJhwGzgB9lG0l6JI2StA74DfDdiKh5bblMCJLulLR+kJ/TACJiYURMBG4G3ptttENT79qSfRYC26hcX6E0cn0lo0G2leautV1IOgC4HXjfXr0QhRYR25NVpbuBYyXV7PbLbHG7Wsq87EW9a5P0VuANwGuiaAM8DOlvVxY9wMQBr7uBJzOKxYYh6V+/Hbg5Iu7IOp5miIg+Sf8JnAxULRDI5R1CLWVe9kLSycCHgVMjYnPW8VhD7gOOlDRJ0guBNwMrMo7JGpQMvN4IPBwRn8o6njRJmrCzUlFSJ/Ba6rSXRawyup3KUq67lr2IiN5so0qHpJ8CfwJsSjbdU5YKKgBJbwSuAyYAfcC6iJibbVQjJ+n1wGeAUcAXIuLKjENKjaSvAH9NZXnoXwOXRsSNmQaVIkl/BXwfeIhKmwLwkYj4j+yiSoek6cCXqPx3uR9wa0RcXvMzRUsIZmbWHIXrMjIzs+ZwQjAzM8AJwczMEk4IZmYGOCGYmVnCCcHaiqQLk5UthzwLXNJhks5pRlzJ8U+UdL+kbZLOatZ5zKpxQrB283+A10fEucP47GHAkBNCsiJqI34JzKMyA9+s5ZwQrG1IWgIcDqyQ9H5JY5L1/u+TtHbnekvJncD3k2/r90v6y+QQVwOvSJ7F8X5J8yRdP+D435D018nvz0m6XNKPgDmSXibpe5LWSFqZLLu8h4j4RUQ8yO4JUmYtlcu1jMyaISLenSwP8qqI+K2kq6g8f+LtyRT/eyXdSWVlyJMi4vlkqZSvALOBi6k8x+ENAJLm1TjdGGB9RHwsWSvne8BpEfG0pLOBK6msaGuWG04I1s7+BjhV0t8lr0cDh1JZnO56STOB7cBfDOPY26ksmAaVpVamAt+tLJ3DKOCpEcRt1hROCNbOBJy590N7JC2ism7PDCrdqs9X+fw29ux2HT3g9+cjYvuA82yIiDlpBG3WLB5DsHa2ErggWfESSbOS7QcCT0XEDuA8Kt/oAZ4Fxg74/C+AmZL2kzSRytPTBrMRmCBpTnKeDknHpHolZilwQrB2dgWVxwo+qMpD5K9Itv9f4K2S7qHSXfSHZPuDwLbkoeXvB+4Gfk5lpcxrgfsHO0nyaM2zgGskPQCsA/5y7/0kvVxSD/Am4POSNqRzmWaN8WqnZmYG+A7BzMwSTghmZgY4IZiZWcIJwczMACcEMzNLOCGYmRnghGBmZon/D4tjyOlImOFsAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plt.scatter(X_train[y_train==0, 0], X_train[y_train==0, 1], label='class 0', marker='o')\n", - "plt.scatter(X_train[y_train==1, 0], X_train[y_train==1, 1], label='class 1', marker='s')\n", - "plt.title('Training set')\n", - "plt.xlabel('feature 1')\n", - "plt.ylabel('feature 2')\n", - "plt.xlim([-3, 3])\n", - "plt.ylim([-3, 3])\n", - "plt.legend()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAEWCAYAAABmE+CbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAaWUlEQVR4nO3dfZRV9X3v8feH6diZAJEo3AozGFEsMeHxZnygJphWDZoVH4pmNWolJNdwk7siJrYkGtKIGo2KafpgGwrFtlhuol6VuBLrRBdNTMxCBSHAFAnGxsuMpCHjnQgBwtP3/nE2MAMzZ87Medjn4fNaa9Y6Z59z9v5udJ3P+T3s31ZEYGZmNiTtAszMrDw4EMzMDHAgmJlZwoFgZmaAA8HMzBIOBDMzAxwIZmaWcCBYTZC0q9vfIUl7uj2/Lo/9rpb0p4WsNdnvpyQ9W+j9mmXzO2kXYFYKETHs8GNJPwduiAh/4Zp14xaCGSCpTtJfSHpN0q8krZA0InltqKRvSXpTUpekFyS9Q9LXgLOBf0xaGl/rZb+9fjZ57SRJyyX9QtI2SbdJGiJpGvBXwAeS/f6ilP8WVrscCGYZ84EPAu8DmoH9wNeT124g05puAkYCnwH2RcSfAS+RaW0MS54fq9fPJq+tAH4NnA6cA1wJXB8R64DPAt9P9ntKgc/VrFcOBLOM/wncEhFvRMRe4HbgTySJTDiMAs6IiAMR8VJE/CbH/fb6WUnvBGYAN0fE7ojYDvwN8NGCn5lZjjyGYDUv+dIfCzwlqftqj0OAk4FlwCnA/5E0DFgO/EVEHMxh971+Fngn0ADsyBz+yPFezf+MzAbHgWA1LyJCUgcwKyLW9vG2LwNflnQ60Aq0kenyybpccET8to/P/hjYBbwjel9y2MsQW8m5y8gsYzFwj6SxAJL+m6TLkscXSXq3pCHAW8AB4HDr4L/IjAH0qq/PRsR/AquB+yQNTwaTz5T0vm77HSupvgjnatYrB4JZxn3As8AqSTvJ/IL/78lrTcC3gZ3AJuAp4JHkta8DsyX9P0n39bLfbJ+9BhgBvAK8CTwM/F7y2tPAz4FfSmovzCmaZSffIMfMzMAtBDMzS6QWCJIaJL0o6SeS2iTdnlYtZmaWYpdRMtVvaETsSgbOfgTcFBGrUynIzKzGpTbtNJlqtyt5Wp/8eUDDzCwlqV6HIKkOWAuMB/4uIl7o5T1zgbkAQ4cOfe+73vWu0hZpZlbh1q5d+6uIGNXf+8pillGyiNgTwI0Rsamv97W0tMSaNWtKV5iZWRWQtDYiWvp7X1nMMoqILuD7wCUpl2JmVrPSnGU0qtvywo3ARWQu0DEzsxSkOYYwGviXZBxhCPBIRHwnxXrMzGpamrOMNgDT0jq+mVWu/fv3097ezt69e9Mupaw0NDTQ3NxMff3glsDyaqdmVnHa29sZPnw4p512Gt2WD69pEUFnZyft7e2MGzduUPsoi0FlM7OB2Lt3LyeffLLDoBtJnHzyyXm1mhwIZlaRHAbHy/ffxIFgZmaAA8HMrGAWLlzI/fffX5R9r127lkmTJjF+/HjmzZtHMS4qdiCYmVWAT3/60yxZsoStW7eydetWnn766YIfw4FgZlVv5boOzr9nFeNu+S7n37OKles68t7n8uXLmTx5MlOmTOH6668/7vWlS5dy9tlnM2XKFK666ip2794NwKOPPsrEiROZMmUKM2bMAKCtrY1zzjmHqVOnMnnyZLZu3dpjX9u3b+ett95i+vTpSGL27NmsXLky73M4lqedmllVW7mug1sf38ie/ZnbYHd07eHWxzcCcOW0pkHts62tjbvuuovnn3+ekSNH8uabbx73nlmzZvHJT34SgC996UssW7aMG2+8kTvuuIPW1laampro6uoCYPHixdx0001cd9117Nu3j4MHD/bYV0dHB83NzUeeNzc309GRf6gdyy0EM6tqi1q3HAmDw/bsP8ii1i2D3ueqVau4+uqrGTlyJAAnnXTSce/ZtGkT73//+5k0aRIrVqygra0NgPPPP585c+awdOnSI1/806dP5+677+bee+/l9ddfp7Gxsce+ehsvKMYsKweCmVW1N7r2DGh7LiKi3y/kOXPm8MADD7Bx40Zuu+22I9cHLF68mK985Sts27aNqVOn0tnZybXXXsuTTz5JY2MjM2fOZNWqVT321dzcTHt7+5Hn7e3tjBkzZtD198WBYGZVbcyIxgFtz8WFF17II488QmdnJ0CvXUY7d+5k9OjR7N+/nxUrVhzZ/rOf/Yxzzz2XO+64g5EjR7Jt2zZee+01Tj/9dObNm8fll1/Ohg0beuxr9OjRDB8+nNWrVxMRLF++nCuuuGLQ9ffFgWBmVW3+zAk01tf12NZYX8f8mRMGvc/3vOc9LFiwgAsuuIApU6Zw8803H/eeO++8k3PPPZeLL76Y7jf2mj9/PpMmTWLixInMmDGDKVOm8PDDDzNx4kSmTp3KK6+8wuzZs4/b3ze+8Q1uuOEGxo8fzxlnnMGll1466Pr7UhY3yMmVb5BjZgCbN2/mrLPOyvn9K9d1sKh1C2907WHMiEbmz5ww6AHlctfbv02uN8jxLCMzq3pXTmuq2gAoJHcZmZkZ4EAwM7OEA8HMzAAHgpmZJRwIZmYGOBDMzAqmmMtfL1iwgLFjxzJs2LCi7B8cCGZmFeGyyy7jxRdfLOoxfB2CmVW3u5tg367jt58wDL44+BVDly9fzv33348kJk+ezEMPPdTj9aVLl7JkyRL27dvH+PHjeeihh3jb297Go48+yu23305dXR0nnngizz33HG1tbXz84x9n3759HDp0iMcee4wzzzyzx/7OO++8QdeaKweCmVW33sIg2/YclHr561Jxl5GZ2QCVevnrUnEgmJkNUKmXvy4VB4KZ2QCVevnrUnEgmJkNUBrLX3/+85+nubmZ3bt309zczMKFCwt+Xqktfy1pLLAcOAU4BCyJiL/O9hkvf21mMMDlr4s0y6hcVery1weAP4uIlyUNB9ZKeiYi/iPFmsys2lThl36xpNZlFBHbI+Ll5PFOYDPgBcvNzFJSFmMIkk4DpgEvpFuJmVWKSrrbY6nk+2+SeiBIGgY8Bnw2It7q5fW5ktZIWrNjx47SF2hmZaehoYHOzk6HQjcRQWdnJw0NDYPeR6r3VJZUD3wHaI2Iv+zv/R5UNjOA/fv3097efmRuv2U0NDTQ3NxMfX19j+1lP6iszFUdy4DNuYSBmdlh9fX1jBs3Lu0yqk6aXUbnA9cDfyRpffL3oRTrMTOraam1ECLiR0D2a7/NzKxkvNqpWTmqsYuprDykPsvIzHpRhCWbzfrjQDAzM8CBYGZmCQeCmZkBDgQzM0s4EMzK0QnDBrbdrAA87dSsHA1kaqmnqFqBOBDMKp2nqB7lcMyLu4zMrHo4HPPiFoKZ1YaFJx6/zS2HHtxCMLPa5ZZDDw4EMzMDHAhmlc9TVK1APIZgVmk8k6ZvJwxzN1AeHAhmlcYzafrWVyD2NqBsx3GXkZlVP3er5cQtBLNy4u6g4vC/XU7cQjArJ+4OshS5hWBWKe5uqqxfum7tVBy3EMwqxeEv10rpD3drp+K4hWBWafzr2orEgWBmfXO3T01xl5FZOXG3j6XIgWBWTvyr21LkQDArN5UyaNyfajmPGuIxBLNyU+mtBI87VCy3EMyssDzuULFSbSFIehD4MPDLiJiYZi1m1ou+Vg8tdLePWxVlIe0uo38GHgCWp1yHWfrK8UuxVMd1q6IspNplFBHPAW+mWYNZ2fCXoqWs7McQJM2VtEbSmh07dqRdjplZ1Sr7QIiIJRHREhEto0aNSrscM+uPp5tWrLTHEMysUMplDMKDwBWr7FsIZpajSh6DcKuiLKQ97fSbwAeAkZLagdsiYlmaNZmlplRTPMuRWxVlIdVAiIhr0jy+WVmphS/FcunWsl65y8jMSqeSu7VqgAPBzMwAB4JZ9fDArOXJ007NqoX74C1PDgSzcuUBWCsxdxmZlatqHIB1t1ZZcwvBzErHLZuy5haCmZkBDgQzM0s4EMzMDMgSCJLGSvqWpB9K+qKk+m6vrSxNeWY1zAOwVmLZBpUfBB4DVgP/A/iBpMsiohN4ZymKM6tpHoC1EssWCKMiYnHy+EZJfwo8J+lyIIpfmpmZlVK2QKiX1BARewEi4l8l/QJoBYaWpDozMyuZbIPK/wic231DRDwLfATYVMyizMys9PpsIUTE1/vYvg64uGgVmZlZKjzt1MzMAAeCmZklHAhmZgbksLidpN8D7gbGRMSlkt4NTI+IZUWvzswsV14uPG+5tBD+mcxU0zHJ858Cny1WQWZmg1KNy4WXWC6BMDIiHgEOAUTEAeBgUasyM7OSyyUQfiPpZJKrkyWdB/y6qFWZmVnJ5XKDnJuBJ4EzJD0PjAKuLmpVZnY895FbkWUNBElDgAbgAmACIGBLROwvQW1mVWvlug4WtW7hja49jBnRyPyZE7hyWlP2D7mP3IosayBExCFJX4uI6UBbiWoyq2or13Vw6+Mb2bM/MxTX0bWHWx/fCNB/KFjfThjWdwvKcpJLl9H3JF0FPB4RXuXULE+LWrccCYPD9uw/yKLWLQ6EfLjbLG+5DCrfDDwK/FbSW5J2SnqrEAeXdImkLZJelXRLIfZpVu7e6NozoO1mpdJvIETE8IgYEhEnRMTbk+dvz/fAkuqAvwMuBd4NXJNc9GZW1caMaBzQdrNS6TcQJM3o7a8Axz4HeDUiXouIfcC3gCsKsF+zsjZ/5gQa6+t6bGusr2P+zAnZP+hbalqR5TKGML/b4wYyX+RrgT/K89hNwLZuz9s55v4LAJLmAnMBTj311DwPaZa+w+MEA55l5D5yK7J+AyEiLuv+XNJY4L4CHFu9Ha6X4y8BlgC0tLR4UNuqwpXTmjyAbGVnMKudtgMTC3DsdmBst+fNwBsF2K+ZmQ1CLqud/i1Hf7kPAaYCPynAsV8CzpQ0DugAPgpcW4D9mpnZIOQyhrCm2+MDwDcj4vl8DxwRByR9hsxKqnXAgxHhi99SMKirZq08eXkLy0MugTAiIv66+wZJNx27bTAi4ingqXz3Y4Pnq2arjJe3sDzkMobwsV62zSlwHZaSbFfNmllt6bOFIOkaMn364yQ92e2l4UBnsQuz0vBVs2Z2WLYuox8D24GRwNe6bd8JbChmUVY6Y0Y00tHLl7+vmjWrPX0GQkS8DrwOTC9dOVZq82dO6DGGADleNWtmVSeXpSvOk/SSpF2S9kk6WKjF7Sx9V05r4quzJtE0ohEBTSMa+eqsSR5QrlRe3sLykMssowfIXCPwKNACzAbGF7MoKy1fNVtFPLXU8pBLIBARr0qqi4iDwD9J+nGR6zIzsxLLJRB2SzoBWC/pPjIDzUOLW5bZ4PgiO7PBy+U6hOuT930G+A2Z9YeuKmZRZoNx+CK7jq49BEcvslu5zt0oZrnI5QY5r5NZmXR0RNweETdHxKvFL81sYHyRnVl+cplldBmwHng6eT71mAvVzMqCL7Izy08uXUYLydwUpwsgItYDpxWvJLPB8a0pzfKTSyAciIhfF70SszwN+taUZgbkNstok6RrgTpJZwLzyCxrYSnybJrjDfrWlGYGgCKy35VS0tuABcAHk02twFciYm+RaztOS0tLrFmzpv83Vrljl6yGzC/hWrjC2EFoNnCS1kZES3/v67PLSNJDycNPRsSCiDg7+ftSGmFgR9XqbBpPKzUrrmxjCO+V9E7gE5LeIemk7n+lKtCOV6uzaWo1CM1KJdsYwmIyU01PB9aSuRbhsEi2WwpqdcnqWg1Cs1Lps4UQEX8TEWeRudfx6RExrtufwyBFtTqbxtNKzYorlyuVP12KQix3tbpkda0GoVmp5LTaqZWfWlyy2tNKzYrLgWAVpRaD0KxUcrlS2czMaoADwczMAAeCmZklHAhmZgY4EMzMLJFKIEj6iKQ2SYck9bvgkpmZFV9aLYRNwCzguZSOb2Zmx0jlOoSI2Awgqb+3mplZiZT9GIKkuZLWSFqzY8eOtMsxM6taRWshSHoWOKWXlxZExLdz3U9ELAGWQOYGOQUqz8zMjlG0QIiIi4q1bzMzK7yy7zIyM7PSSGva6R9LagemA9+V1JpGHWZmdlRas4yeAJ5I49hmZtY7dxmZmRngQDAzs4RvkGP9Wrmuw3cpM6sBDgTLauW6Dm59fCN79h8EoKNrD7c+vhHAoWBWZdxlZFktat1yJAwO27P/IItat6RUkZkViwPBsnqja8+AtptZ5XKXkWU1ZkQjHb18+Y8Z0ZhCNaXjcROrRW4hWFbzZ06gsb6ux7bG+jrmz5yQUkXFd3jcpKNrD8HRcZOV6zrSLs2sqNxCsKwO/you91/LhfxFn23cpNzO26yQHAjWryunNZX1F2GhZ0J53MRqlbuMrOIVeiZUX+Mj1T5uYuZAsIpX6F/05TZusnJdB+ffs4pxt3yX8+9Z5bEMKxp3GVnFK/RMqHIaN/GFgVZKDgSrePNnTujxpQn5/6Ivl3ETD3BbKTkQrOKV0y/6QvMAt5WSA8GqQrn8oi+0Wr0w0NLhQWWzMlZuA9xW3dxCMMtDsZe4qObuMCs/DgSzQSrVDKBq7Q6z8uMuI7NB8tLgVm0cCGaD5BlAVm0cCGaD5CUurNo4EAzw8giD4RlAVm08qGxeHmGQKmEGkG/0YwPhQDAvj5CHcp4B5KC3gXKXkXlwtEp5FpQNlAPBPDhapRz0NlCpBIKkRZJekbRB0hOSRqRRh2V4cLQ6OehtoNJqITwDTIyIycBPgVtTqsPI9Cd/ddYkmkY0IqBpRCNfnTXJ/cwVzkFvA5XKoHJEfK/b09XA1WnUYUeV8+CoDU4lzIKy8lIOs4w+ATycdhFm1chBbwNRtECQ9CxwSi8vLYiIbyfvWQAcAFZk2c9cYC7AqaeeWoRKzcwMihgIEXFRttclfQz4MHBhRESW/SwBlgC0tLT0+T4zM8tPKl1Gki4BvgBcEBG706jBzMx6SmuW0QPAcOAZSeslLU6pDjMzS6Q1y2h8Gsc1M7O++UplMzMDHAhmZpZwIJiZGeBAMDOzhAPBzMwAB4KZmSUcCGZmBjgQzMws4UAwMzPAgWBmZgkHgpmZAQ4EMzNLOBDMzAxwIJiZWcKBYGZmgAPBzMwSDgQzMwMcCGZmlnAgmJkZ4EAwM7OEA8HMzAAHgpmZJRwIZmYGOBDMzCzhQDAzM8CBYGZmCQeCmZkBDgQzM0ukEgiS7pS0QdJ6Sd+TNCaNOszM7Ki0WgiLImJyREwFvgN8OaU6zMwskUogRMRb3Z4OBSKNOszM7KjfSevAku4CZgO/Bv4wy/vmAnOTp7+VtKkE5aVlJPCrtIsoomo+v2o+N/D5VboJubxJEcX5cS7pWeCUXl5aEBHf7va+W4GGiLgth32uiYiWApZZVnx+lauazw18fpUu1/MrWgshIi7K8a3/G/gu0G8gmJlZ8aQ1y+jMbk8vB15Jow4zMzsqrTGEeyRNAA4BrwOfyvFzS4pXUlnw+VWuaj438PlVupzOr2hjCGZmVll8pbKZmQEOBDMzS1RcIFTzsheSFkl6JTm/JySNSLumQpL0EUltkg5JqpopfpIukbRF0quSbkm7nkKS9KCkX1br9T+Sxkr6d0mbk/83b0q7pkKR1CDpRUk/Sc7t9n4/U2ljCJLefvhKZ0nzgHdHRK6D0mVN0geBVRFxQNK9ABHxhZTLKhhJZ5GZSPAPwJ9HxJqUS8qbpDrgp8DFQDvwEnBNRPxHqoUViKQZwC5geURMTLueQpM0GhgdES9LGg6sBa6shv9+kgQMjYhdkuqBHwE3RcTqvj5TcS2Eal72IiK+FxEHkqergeY06ym0iNgcEVvSrqPAzgFejYjXImIf8C3gipRrKpiIeA54M+06iiUitkfEy8njncBmoCndqgojMnYlT+uTv6zflxUXCJBZ9kLSNuA6qndhvE8A/5Z2EdavJmBbt+ftVMkXSq2RdBowDXgh3UoKR1KdpPXAL4FnIiLruZVlIEh6VtKmXv6uAIiIBRExFlgBfCbdagemv3NL3rMAOEDm/CpKLudXZdTLtqpptdYKScOAx4DPHtMLUdEi4mCyqnQzcI6krN1+qS1ul001L3vR37lJ+hjwYeDCqLQBHgb0365atANjuz1vBt5IqRYbhKR//TFgRUQ8nnY9xRARXZK+D1wC9DlBoCxbCNlU87IXki4BvgBcHhG7067HcvIScKakcZJOAD4KPJlyTZajZOB1GbA5Iv4y7XoKSdKowzMVJTUCF9HP92UlzjJ6jMxSrkeWvYiIjnSrKgxJrwK/C3Qmm1ZXywwqAEl/DPwtMAroAtZHxMx0q8qfpA8BfwXUAQ9GxF0pl1Qwkr4JfIDM8tD/BdwWEctSLaqAJL0P+CGwkcx3CsAXI+Kp9KoqDEmTgX8h8//lEOCRiLgj62cqLRDMzKw4Kq7LyMzMisOBYGZmgAPBzMwSDgQzMwMcCGZmlnAgWE2RNC9Z2XLAV4FLOk3StcWoK9n/DEkvSzog6epiHcesLw4EqzX/C/hQRFw3iM+eBgw4EJIVUXPxf4E5ZK7ANys5B4LVDEmLgdOBJyV9TtLQZL3/lyStO7zeUtIS+GHya/1lSX+Q7OIe4P3JvTg+J2mOpAe67f87kj6QPN4l6Q5JLwDTJb1X0g8krZXUmiy73ENE/DwiNnD0AimzkirLtYzMiiEiPpUsD/KHEfErSXeTuf/EJ5JL/F+U9CyZlSEvjoi9yVIp3wRagFvI3MfhwwCS5mQ53FBgU0R8OVkr5wfAFRGxQ9KfAHeRWdHWrGw4EKyWfRC4XNKfJ88bgFPJLE73gKSpwEHg9wex74NkFkyDzFIrE4FnMkvnUAdsz6Nus6JwIFgtE3DVsTftkbSQzLo9U8h0q+7t4/MH6Nnt2tDt8d6IONjtOG0RMb0QRZsVi8cQrJa1AjcmK14iaVqy/URge0QcAq4n84seYCcwvNvnfw5MlTRE0lgyd0/rzRZglKTpyXHqJb2noGdiVgAOBKtld5K5reAGZW4if2ey/e+Bj0laTaa76DfJ9g3AgeSm5Z8Dngf+k8xKmfcDL/d2kOTWmlcD90r6CbAe+INj3yfpbEntwEeAf5DUVpjTNMuNVzs1MzPALQQzM0s4EMzMDHAgmJlZwoFgZmaAA8HMzBIOBDMzAxwIZmaW+P/m3gMCrwIYqAAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plt.scatter(X_test[y_test==0, 0], X_test[y_test==0, 1], label='class 0', marker='o')\n", - "plt.scatter(X_test[y_test==1, 0], X_test[y_test==1, 1], label='class 1', marker='s')\n", - "plt.title('Test set')\n", - "plt.xlabel('feature 1')\n", - "plt.ylabel('feature 2')\n", - "plt.xlim([-3, 3])\n", - "plt.ylim([-3, 3])\n", - "plt.legend()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Defining the Perceptron model" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", - "\n", - "\n", - "class Perceptron():\n", - " def __init__(self, num_features):\n", - " self.num_features = num_features\n", - " self.weights = torch.zeros(num_features, 1, \n", - " dtype=torch.float32, device=device)\n", - " self.bias = torch.zeros(1, dtype=torch.float32, device=device)\n", - " \n", - " # placeholder vectors so they don't\n", - " # need to be recreated each time\n", - " self.ones = torch.ones(1)\n", - " self.zeros = torch.zeros(1)\n", - "\n", - " def forward(self, x):\n", - " linear = torch.add(torch.mm(x, self.weights), self.bias)\n", - " predictions = torch.where(linear > 0., self.ones, self.zeros)\n", - " return predictions\n", - " \n", - " def backward(self, x, y): \n", - " predictions = self.forward(x)\n", - " errors = y - predictions\n", - " return errors\n", - " \n", - " def train(self, x, y, epochs):\n", - " for e in range(epochs):\n", - " \n", - " for i in range(y.shape[0]):\n", - " # use view because backward expects a matrix (i.e., 2D tensor)\n", - " errors = self.backward(x[i].reshape(1, self.num_features), y[i]).reshape(-1)\n", - " self.weights += (errors * x[i]).reshape(self.num_features, 1)\n", - " self.bias += errors\n", - " \n", - " def evaluate(self, x, y):\n", - " predictions = self.forward(x).reshape(-1)\n", - " accuracy = torch.sum(predictions == y).float() / y.shape[0]\n", - " return accuracy" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Training the Perceptron" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model parameters:\n", - " Weights: tensor([[1.2734],\n", - " [1.3464]])\n", - " Bias: tensor([-1.])\n" - ] - } - ], - "source": [ - "ppn = Perceptron(num_features=2)\n", - "\n", - "X_train_tensor = torch.tensor(X_train, dtype=torch.float32, device=device)\n", - "y_train_tensor = torch.tensor(y_train, dtype=torch.float32, device=device)\n", - "\n", - "ppn.train(X_train_tensor, y_train_tensor, epochs=5)\n", - "\n", - "print('Model parameters:')\n", - "print(' Weights: %s' % ppn.weights)\n", - "print(' Bias: %s' % ppn.bias)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Evaluating the model" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Test set accuracy: 93.33%\n" - ] - } - ], - "source": [ - "X_test_tensor = torch.tensor(X_test, dtype=torch.float32, device=device)\n", - "y_test_tensor = torch.tensor(y_test, dtype=torch.float32, device=device)\n", - "\n", - "test_acc = ppn.evaluate(X_test_tensor, y_test_tensor)\n", - "print('Test set accuracy: %.2f%%' % (test_acc*100))" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaoAAADCCAYAAAAYX4Z1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAgAElEQVR4nO3dd3hUVfrA8e/JZEJ6AkloCSH0lhBCQg02UGFXqVKliKgg1bb4w2VXEBsu665KDwsiRQUVAugqrsu6LlUSepUmkoCCoUpCv78/kmiATDKT3Jl7Z+b9PA/PIyl3XuKcvPec9z3nKk3TEEIIIczKx+gAhBBCiJJIohJCCGFqkqiEEEKYmiQqIYQQpiaJSgghhKlJohJCCGFqvka8aGRkpBYXF2fESwuhi8zMzJ81TYsyOo5CMqaEJ7A1rgxJVHFxcWRkZBjx0kLoQil11OgYipIxJTyBrXElS39CCCFMTRKVEEIIU5NEJYQQwtQMqVEJIZzv6tWrZGVlcenSJaNDMRV/f39iYmKwWq1GhyLs5FGJ6j/7TvLf704x/oFGWC0yWRTeLSsri5CQEOLi4lBKlekaFy9f48zFK1SvGIBPGa9hJpqmkZOTQ1ZWFrVq1TI6HGEnj/ptvu3YWeav/57+czZx6sJlo8MRwlCXLl0iIiKizEkKIPfKdU7nXuHIzxe5dv2GjtEZQylFRESEzDLdjEclqmfuq8/bfZuxI/ssnaeuZduxs0aHJIShypOkAKJCKlCjUiC5V65z6NRFLl+7rlNkxinvz0S4nkclKoCuzaL5ZHhbfC2K3rM2sHTzMaNDEsKtVQz0o3ZkENdu3ODQyYtcvHytXNebOHEif/3rX3WK7maZmZkkJCRQt25dxowZgzxvzzN4XKICaFI9jFWj2tGyViWe/2QHf07fxZVr7r9sIYRRgir4UjcqGIsPHP75ImdzrxgdUrGGDx9OWloaBw4c4MCBA3zxxRdGhyR04JGJCqBikB/zH23BsDtrs3DjUR6es5GTF2RdWghb0rdmkzp5DbXGfUbq5DWkb82+6fMVrBbqRAUTaLXww+lcTp6/VOqMZcGCBTRt2pTExEQGDhx42+fnzJlDixYtSExM5KGHHiI3NxeAjz76iPj4eBITE7nzzjsB2L17Ny1btqRZs2Y0bdqUAwcO3HStEydOcP78edq0aYNSikGDBpGenl6eH4kwCY9NVAC+Fh9e+H0jpvZLYvfx83SeupatP5wxOiwhTCd9azYvLNtJ9tk8NCD7bB4vLNt5W7LytfhQKyqI8EA/fjx/iawzedywkax2797Nq6++ypo1a9i+fTtvv/32bV/To0cPNm/ezPbt22nUqBFz584FYNKkSaxevZrt27ezcuVKAGbNmsVTTz3Ftm3byMjIICYm5qZrZWdn3/SxmJgYsrNvjl+4J49OVIU6J1Zn2Yi2+Pn60Gf2RpZs/sHokIQwlSmr95N39eZGibyr15myev9tX+ujFDUqBlAl1J8zuVf43kZH4Jo1a+jZsyeRkZEAVKpU6bav2bVrF3fccQcJCQksXryY3bt3A5CamsrgwYOZM2cO16/nx9WmTRtee+013njjDY4ePUpAQMBN1ypudieNE57BKxIVQKNqoawa1Y5WtSvxf5/sZPzynVK3EqLA8bN5Dn1cKUWVUH9qVAzkoo2OQE3TSk0UgwcPZtq0aezcuZMJEyb82jY+a9YsXnnlFY4dO0azZs3Iycnh4YcfZuXKlQQEBNCxY0fWrFlz07ViYmLIysr69e9ZWVlUr1691H+7MD+vSVQA4YF+zH+0JcPvrsPiTT/Qb85GTp6XupUQ1cMDHPp4oYpBftSy0RHYoUMHli5dSk5ODgCnT5++7fsvXLhAtWrVuHr1KosXL/7144cOHaJVq1ZMmjSJyMhIjh07xuHDh6lduzZjxoyhS5cu7Nix46ZrVatWjZCQEDZu3IimaSxYsICuXbva/TMQ5uVViQrA4qP4v04Nmf5wc/YcP8+DU9eSeVTqVsK7je3YgACr5aaPBVgtjO3YoNTvDa7gS52oYHxu6Qhs0qQJ48eP56677iIxMZFnn332tu99+eWXadWqFffddx8NGzb8LZ6xY0lISCA+Pp4777yTxMRElixZQnx8PM2aNWPfvn0MGjTotuvNnDmTxx9/nLp161KnTh1+97vfOfqjECakjNhnkJKSopnh2Tn7fjzPsIWZHD+bx0td4nm4VazRIQk3oZTK1DQtxeg4ChU3pvbu3UujRo3svkb61mymrN7P8bN5VA8PYGzHBnRLirb7+69dv8HRnFwuXrlG1VB/okIqmLZG5OjPRriGrXHlUWf9Oaph1VBWjmzHmA+38sflO9mZfZaJXZpQwddS+jcL4WG6JUU7lJhu5WvxoVZkEFln8vjx/CWuXLvhMWcECmN53dLfrcICrcwb3IIRd9fhg2+P0TdtIz9J3UqIMvHxUdSoFEDlEH9OF3YE3pCmJVE+Xp+oIL9u9Xynhszo35z9P17gwalryfj+9sKvEKJ0SimqhvkTU9gRePIiVzzgjEBhHElURfw+oRrpI1MJ8rPQb85GFm08KmeFCVFGlYL8qBURyLUbNzh48iK55TwjUHgvSVS3qF8lhBWj2tGubiR/St/FuE92esSJ0UIYIdjfelNH4DmTnhEozK3ciUopVUMp9R+l1F6l1G6l1FN6BGaksAArcx9pwej2dVmScYw+szfy4zmpWwlRFv5WC3WjgvG3Wjh6OpeTF0o/I1CIovSYUV0DntM0rRHQGhiplGqsw3UN5eOjeO7+BswakMyBn/LrVt8ekbqVcD5PvPnztfhQOzKI8AArEyZM5M8vv27zjMDyGD9+PDVq1CA4OFj3awvjlDtRaZp2QtO0LQX/fQHYC5S9x9VkOsVXJX1kKqH+vjw8ZyMLNnwvd4PC2Tz25q9GpUCC/CzkXbnO9z9f5LrOHYGdO3fm22+/1fWawni61qiUUnFAErCpmM8NVUplKKUyTp06pefLOl29KiGkj0rlrvpRvLhiN89/vINLV6VuJZzDkJu/16JhYtjtf14r38ve+pgPpRTB/lbCAq1cvHyd1/8+nZQUfR7zAdC6dWuqVatWrpiF+eiWqJRSwcAnwNOapp2/9fOapqVpmpaiaVpKVFSUXi/rMqH+VuYMSmFMh3p8lJlFn9kbbB7YKYReSrr509WVXxz7uB1KesxHoJ8vtSIDuafjgyxY+W82bM4s92M+hOfSJVEppazkJ6nFmqYt0+OaZuTjo3j2vvqkDUzm0KmLdJm2lk2Hc4wOS3io0m7+zL5KUdpjPoL9reT+dIRB3TuR2DSRhYsWlesxH8Jz6dH1p4C5wF5N0/5W/pDM7/4mBXWrACv9/7GJ+euOSN1K6Mqemz+zr1LY85iPoY8/xqyZ0/n8m008PuZ5zl64iKZpZXrMh/Bcepz1lwoMBHYqpbYVfOyPmqb9U4drm1bdysGkj0zl2SXbmbhqDzuzz/Nq93j8rV56TuBr0cUvE/kFwx/lKauO8JSbvw4dOtC9e3eeeeYZIiIiOH369G2zqgsXLhAbE01YmB9frvyYSlFVyT6bx6Wc47Rq1YpWrVqxatUqjh07xrlz5359zMfhw4fZsWMH7du3N+hfJ1yp3IlK07S1gFeeOhnqbyVtYDJT1xzk7199x3c/XWDWwGSiS3mGj1uzlZBsufJL/vdIsnKER9z8FX3Mh8ViISkpifnz59/0NYWP+ahZsyaJ8fGcOn2O0xevMPbpZzl+NH+lokOHDiQmJjJ58mQWLVqE1WqlatWqvPjii7e95vPPP8/7779Pbm4uMTExPP7440ycONE1/2DhNF79mA89fbXnJ55Zsg2rrw/TH25OmzoRRofkHBPDyvh95/SNw2Ae95gPE82IT1+8TPaZS1Sw+hAXEYSfr/4H6Bj2mA8T/ZzNSB7z4WT3Nq5C+qhUhi7IYMDcTYz/fSMeTY0z7fN4XK64BCeD0zxM9P+hUlAFrBYffsjJ5eCpX4iLCCTQz0N+VTmhu9IbyFl/OqoTlV+36tCwMpM+3cNzS7fLfquSyOAUNoT4W6lTORgf4PCpi5zLu2p0SMJAkqh0FuJvZdaAZJ67rz7Lt2Xz0Mz1ZJ3JNTosIdyOv9VCncoFZwTmXOTUhcvSXeulJFE5gY+PYnSHesx9JIUfcnLpPHUt6w/+bHRYwgu5+y92a8EZgWEBVk6cy+P42fIfaOvuPxNvJInKido3rMKKUalEBldg4Lxv+cf/Drv/IPGzcdinX7DtzwlD+Pv7k5OT4/bvOR8fRWylQKJCKpBz8TLf5+Ry/UbZ/k2appGTk4O/v7/OUQpn8pAKpXnVjgpm+chU/rB0O698tped2eeY3KMpAX5uut+qLEX3snYKinKJiYkhKysLM55aUVZXLl/jUO5VjloUkcEVsPg43qzk7+9v3PFLfsG2u/6kI9AmSVQuEFzBl5kDmjPj60P89cv9HPjpF2YPTKZGpUCjQ3MNW4MTpBvQiaxWK7Vq1TI6DN3997tTjFy8hUA/C/MGtyA+2o1uhEp6X9u6oTNj05GLk6os/bmIUoqR99Rl3iMtOHYmly7T1rL2gJfUrf6Ynb+P6tY/tphxYArTuKt+FJ8Mb4vV4kOvWRv4as9PRofkfVzcZi+JysXuaViZVaPaERVSgUHzNjHnGw+oWwnhYg2qhrB8RFvqVQnmiYUZvLvuiNEhiUI6PiamkCQqA8RFBrF8RCqd4qvy6j/3MubDbeReuWZ0WEK4lcqh/nw4tDX3NarCS6v2MGHFLq5d1/dBjKKcdJphSaIySFAFX6Y/3JznOzXg0x3H6TFjPcdOy34rIRwR6OfLzAHJPHFHLd7bcJShCzO5eFlu+jyNNFMYSCnFiLvr0rhaKGM+2ErnaWuZ2i+JO+qZ75ENHks6rdyexUcx/oHGxEYEMWHFLnrN2sDcwSlUC3Ozw6FL6gj0cpKoTODuBpVZNbodwxZm8si8b3m+U0OG3Vnb888JdKQbsKyJo7REJGeveYyBrWsSUzGAUYu30G36OuY+4gYdge56o1TS2HUCSVQmUTMiiGUj2jL24x1M/nwfO7PPMaVnU/c9jNOeAVjcQNS7RVcSkVe5p0FlPh7eliHzN9N79gam9kuiQ6MqRodlm7u+P28du07eK+mmvwU9U6CfL9P6JZEQHcZfvtjHoZP5+61qRgQZHZrj3HUACrfXqFoo6SNTeey9zTyxIIMJnZvwSNs4o8Ny/Flu7sTJy5aSqExGKcWTd9WhcbVQRn+wlS7T1vFOvyTuqm9g3aqk2RF47uATbqtKqD9Lh7VhzAfbmLByN9/nXORPDzQu00kWuvHkceLkZUpJVCZ1Z/0oVo1qx9CFGQx+91vGdmzA8LvqGFO3ktmRcEOBfr7MHpjMq5/tZd66Ixw7ncvbfZMIquDFv/bctCYm7ekmFhsRyLIRbXmwaXX+8sV+Rr6/RVpv9VbSIbvC7Vl8FC92bsykrk1Ys+8kvWdv4Kfzl4wOyzhuetPpxbcW7iHQz5d3+jajaXQYr3++l4MnfyFtYApxkW5Yt7KH3mvdpV3PxHeRQj+D2sRRo2Igo97/rSOwcfVQo8OyTW6UbiKJyg0opXjizto0qhbKqA+20GXaWt7ul8Q9DSrr9yKl1aH0Utr19E4ckohEgXsaVmbpk214bH4GvWatZ1r/5vqOofIo6exLe7jpkp69dElUSql5wIPASU3T4vW4prhdu3qRrBqVv99qyPzN/OH+Boy4W6e6lTOWBMo7+Ap5+CAUrtOkehjpI1MZMn8zj83fzEtd4xnYuqbRYZWfmy7p2UuvGtV8oJNO1xIlqFEpkE+Gt6VLYnWmrN7P8EVb+MXZdauyzKr0nIl5+CAUrlU1zJ+PnmzD3Q0q8+f0Xbzy6Z4yP4hRuIYuMypN075RSsXpcS1RugA/C2/1aUZCdBivf76P7tPXkTYohVrOqls5sjEX9JtJCeEkQRV8mTMohZc/3cM/1h7h6Olc3u7bzLkb7M1wRJIZYigDl9WolFJDgaEAsbGxrnpZj6WU4vE7atO4Wigj3y+oW/VtRvuGLtqF76ZveCEKWXwUE7s0oWZEIC9/uoc+szcy95EUKoc66TH1Ri5Tu/nyucsSlaZpaUAaQEpKisyzddK2buSv5wQ+9l4Gz9xbn1H31MVHz42Ntp7CKzMnp5G6r+s8mlqL2EqBjP5gK92mr2Peoy1oWNXEHYFlUdryuckTmXT9eYCYivl1qxeW7eRv//qOXdnneLN3IiH+Vvsv4ughk8V9rcnf7G5mPjANWGBwHF6hQ6MqLB3Whsfe20zPmRuY3r+5vqfBOHtslHeFw+R1YElUHsLfauFvvRNJiA7j1X/upVtB3apOlJ1vVFuDxZHDJk3+ZncnUvd1vfjowo7ADIbM38ykrk3o30qnjkBnjw0PvxHUqz39A+BuIFIplQVM0DRtrh7XFvZTSjGkXS0aFdStuk1bx9/7NOPexk6qW+l5yKar9nF5EKn76q9aWAAfPdmG0e9vYfzyXRzNyWVcp4b6LqULh+nV9ddPj+sIfbSpE8Gq0e14cmEmjy/I4Ol76zGmfT39B5ueMyWZjTlM6r7OEVzQETjp0z2kfXOYozkXeatPEgF+Fue/uCyfF0vO+vNQ0eH5d4Y9mkfz1lcHGLowk/OXrhodlhBuwdfiw0tdmvDig435cs9P9E3bwMkLLjgj0Fk3bG5+pqXUqDyYv9XCm70SaRodxsufFdStBqZQt7IDb86SirRGz3bcZJAJ91S4lB5TMYCnPtxG9+nrmTe4BQ2qhhgdmuNKm42ZfLuJJCoPp5RicGotGlYLZWTBI7r/1juR+5tUtf1N9i4/2Ntooceb3cta4aXuax73N6nK0mFtGPLeZnrOXM+MAc25o56DHYFmSwRutsQoicpLtK6dX7caviiToQszGdOhHk93sFG30mP5wcsSi96k7msuCTH5HYGPzd/M4Hc380q3ePq1dKCBxWy//N2sJiw1Kg+UvjWb1MlrqDXuM1InryF9a/4gqR4ewJJhbeiZHMM7/z7AEwsyOJdXjrqVnuveJlliEMKWwrpvu7qRvLBsJ69/vpcbckagS8iMysOkb83mhWU7ybt6HYDss3m8sGwnAN2SovG3WpjSsymJMWG8tGpPQd0qmXpVyrDubs9dor1LDH/MLrnd/dZlRpMuUQjPFuJvZe4jKUxYuZvZ/z3MDzm5/L1PM/ytOnUEmm2J0CQkUXmYKav3/5qkCuVdvc6U1fvplhQN5NetBraJo0HVUEYszqTb9HW82bsZneJLqFsV5cj6tiNLDI4cfmvSJQrh+XwtPrzSLZ5akUG8+s+9nEjbyJxBKUSFVCj/xeXmq1iSqDzM8bN5dn+8Za1K+futFm3hyUWZjG5fl6fvrU+p94YlJR9HTrIQwk0VHgodUzGQp5dspfuMdbw7uIV9KxNu1shgBlKj8jDVwwMc+ni1sACWDG1N75QYpq45yOPvbeac1UZHk5cvPwhxq07xVVkytA2Xrt6gx8z1rDv4c+nfZIZGBjfbVyUzKg8ztmODm2pUAAFWC2M7NrD5Pf5WC2881JSmMeG8tGo3XcPnkPZ4CvXLUrcSwssk1ggnfWRbhszfzCPzvuW17gn0blHD6LBK5mYzN6+YUdnqgvNE3ZKieb1HAtHhASjyO5Ve75Hwa33KFqUUA1rX5IMnWvPL5et0m76Oz3eecE3QQri5mIqBfDy8LW3qRPD8Jzv4yxf7pCNQRx4/oyqtC84TdUuKdujflr41mymr93P8bB7VwwMYdU8dVmw/zvDFWxhxdx2eu78BFr3PCSxcYiip088vuOQTMIrWw2R9Xxgs1N/KvMEteHHFbmZ8fYijObm82TtRv45AL+bxicqeLjhvVlwif+OL/bzctQkNq4Yw4+tD7D5+nnf6JhEWWPB8Kz2OT7ryS+mnr1/55faNw9IFKEzMavHhte7x1IoM5LV/7uP4uTzmDEohMliHjkAv5vFLf450wbmaGZYkbSXyv391gNd7NOW17gmsP/QzXaavZf+PF/K/QK+ZiyQX4YGUUgy9sw4z+zdnz/HzdJ+xjoMnL/z2BW7WyGAGHj+jqh4eQHYxSclWF5yrmGVJsrRE/nCrWBpUDWH4oky6z1jHlJ6JPNC0WskbE29NZNKyLrzQ7xKqUTXMnycWZNBjxnpmDUymbZ1IWaIuA4+fUY3t2ICAW9aIS+uCc4WSliRdyZ529uSaFVk1uh0Nq4Yw8v0tTP58H9fHZeUvy936B/ITU9E/QnippNiKLB+RSpVQfwbN/ZaPMo4ZHZJb8vhEVdYuOGczy5KkvYm8Sqg/Hw5tQ/9Wscz67yEGv/stZ3Ov3H5BWc4T4iY1KuV3BLauHcHYj3fw19X7pSPQQR6/9AeOd8G5glmWJAt/LkW7/sZ2bFDsz8vP14dXuyeQEB3Giyt203naWtIGptCoWmjZAyipMaO4NXs5C024obAAK+8+2oI/Ld/FtP8c5OjpXKb0bCodgXbyikRlRmXZmOssjibyvi1jqV9Qt+oxYz1v9GxKl8Tqjr9wWVrKZX1fuCmrxYfJDyUQFxnEG1/s4/jZPNIGJhMhHYGlkkRlkOJmMvc0jGLK6v08s2RbiTMbM2gem1+3GrFoC2M+2Mru7HOM7dig5DeUPc+oknPQhAdTSjH87jrEVgrkmaXb6DEz/6nBdaJkRaAkkqgMVHQmY5YuQEdUDvHn/SdaM+nT3cz+5jC7j59nqhZMRVWOOpUZzkETwskeaJrfETi0oCNw9sBkWteOMDos09KlmUIp1UkptV8pdVApNU6Pa3obs3QBOsrP14dXuiXwxkMJfHvkNJ2vvMbuGzWL+UK5YxSiqOSa+R2BkcF+DJy7iU8ys4wOyT6vRd/e2TsxLP/jTlLuGZVSygJMB+4DsoDNSqmVmqbtKe+1vYlZugDLqk+LWBpUDeXJhZk8lPcGb3RrStdm5pwJCmEWsRGBLBueyvDFmTz30XaOns7lmXvroZTOR5bpyYBVDz1mVC2Bg5qmHdY07QrwIdBVh+t6FUcfz2FGzWqEs2p0O5pGh/PUh9t49bM9XLt+w+iwhDCngplJ2F8imZ/1e3pZvuadfx/g6T9P4PK166V/vxfRI1FFA0V3sWUVfOwmSqmhSqkMpVTGqVOndHhZc3P0eCSzbkx2VFRIBRY/0YpH2tRkzv+OMGjet5y+WMx+KyG8XZEZiJ+6zl980xjru4QV11ox4B+bZNwUoUeiKm6OettuNk3T0jRNS9E0LSUqysaD+TxEYWNE9tk8NH5rjCgpWbliY7Krzha0Wnx4qWs8U3o2JePoGTpPXcuubDs6/kDOQRMew9HxphSM9F3BVOs7bM86R48Z6zjy80UXRWtuenT9ZQFFnxIWAxzX4bpuq6wntjtzY7IRXYW9UmrQoGoIwxZm8tDM9bzxUNPSX0ta0IUHKM9462zZSPVHW/HEgvzzNWcPSKaVl3cE6jGj2gzUU0rVUkr5AX2BlTpc122ZsTHCqK7CpjH5davEGuE8vWQbk1ZJ3coe0knr3so73pJrVmL5iLZUCvJjwNxNLN9qoo5AA1Y9yj2j0jTtmlJqFLAasADzNE3bXe7I3JhZjkcqysjkGRlcgcWPt+LVz/Yyb90R9pw4x/SHm8uOfBukk9b96THeakYEsXx4KsMWZfDMku0czcnlqQ4m6Ag0YNVDl31Umqb9U9O0+pqm1dE07VU9runOytMY4aw6ktFdhVaLDxO7NOHNXols/eEsnaeuZWeWnXUr7yOdtG7OrvFmx8wkLNDKgiGteKh5DG99dYBnl273yo5AOZnCCRw56LUoZ9aRzHK24EPJMdSvEsKwhRn0nLWe13sk0KN5jEtjcAPFddK2uvWLlFJDgaEAsbGxrolM2MWu8WbnzMTP14e/9mpKXEQgb/7rO7ILzggMD/TTO2zTUprm+uPmU1JStIyMDJe/rr3St2Y7nGT0kDp5TbFLhtHhAawb177c1zfq31WcnF8uM/L9LWw8fJrBbeMY/0AjrBb3eeqMUipT07QUJ127F9BR07THC/4+EGipadpoW99j9jHljZwx3lZsy2bsRzuIrhjAu4NbEBcZpFO05mBrXMmM6hZGnrnn7DqSmR53EhFcgUWPteL1z/cxd+0R9p44z/T+zYmUuhVIJ61HcMZ469osmurhAQxdkEH3GetIG5RCi7hKur6GGbnPLawLpG/N5rml2w07c8/oOpKr+Vp8+PODjXmrTzO2HcuvW+3IOmt0WGYgnbTCphZxlVg+IpXwQD/6z9nEim2ev6VDZlQFCmdS120shbqiO84sdSRX65YUTd3KwQxbmEnPWRt4tVs8vVJqlP6NHsqTO2nNtPxsWnY86iYuMohlw9sybFEmT324jR9ychnVvq7xHYFOYtoZlatOUShU3L6Holwxq3HF6RRmFR8dxqrR7UipWZGxH+9gwopdXPXi/Vae2ElblhNbvJKdh75WDPJj4WMt6Z4UzZv/+o4/fLSDK9c8c8yYckZlRJ2opBmTK2c1ZqojuVqlID8WDGnJG1/sY87/jrD3xAWm929OVIjUrTxBWU9sEbZV8LXwt96J1IwI5K2vDpB9NpfZA1IIC7QaHZquTDmjMuIUBVszJotSXjOrMQNfiw/jH2jM232bsSM7v2617ZjUrTyBGU9s8QRKKZ6+tz5/75PIlqNn6T5zHUdzPOuMQFMmKiPe0LY26b7ZO1GSlAG6Novmk+Ft8bUoes/awNLNx0r/JmFq3tYs5Grdk2JY+FhLTl+8QvcZ68k8etrokHRjykRlxBvam+tDZtWkehirRrWjZa1KPP/JDv6UvtNj1+C9gac8ysbMWtWOYNnwtoT6+9JvziZWbfeMXQ2mrFEZ1f3mqvpQYedT9tk8LEpxXdOIlg6oYlUM8mP+oy2Ysno/s785zL4TF5gxoDmVQ/yNDk04qKwntngdv2DbXX92qB0VzLIRqQxbmMHoD7byw+lcRtxdx607Ak17MoUZ21j1iOnWRpGirBZFkJ8v5/KumuLfbLb/Byu3H+f5j7cTFmBl5oBkmsdWNCwWZ55MURZyMoV9zPaedqbL167z/Mc7WLHtOL2SY3i1ewJ+vqZcRPuV251MYX8bgwUAAA5XSURBVLbuN706EUtqg796XeNs3tXbrl/4fa4cXEae0GFLl8Tq1I0KZtiiDPrO3sikrk3o21LOuBP2MeN72pkq+Fp4q08zakYE8c6/D5B9No+Z/ZPdsiPQ3OnVRPTqRHSkISTv6nUmrtxtyN4To55fVZrG1UNZNaodrWpXYtyynYxfLnUrYR+zvqedSSnFs/fV581eiWz+/jQ9Zq7jh5xco8NymCQqO+nViehoQ8jZvKuGDC4ztxKHB/ox/9GWPHlXHRZv+oF+czZy8vwlo8MSJmfm97SzPZQcw8LHWvHzL1foPmMdW344Y3RIDpFEZaeydCIWd7pGcZ1PZeHswWX2VmKLj2Lc7xoy7eEk9hw/z4NT13pUO67Qn9nf087WunYEy0a0JaiCL/3SNvLZjhNGh2Q3SVR2crS11tZxMcCvbfClCbBaqGhjPblwcKVvzabZS18SN+4z4sZ9RtKkL3VZFnSXVuIHm1Zn+ci2BPhZ6Ju2kfc3/WB0SMKk3OU97Ux1ooJZPqIt8dFhjHx/CzO/PoQRDXWOMm3Xnxk50jFk69lS4QFWtk24v9SvsyjFm70TAYpt1X+9RwIAYz/aztUbN/8/tFoUU3qWf6OyO3VIncu9yugPt/LNd6fo17IGE7s0oYJv+WeutkjXXz53eo+A+8XrLJeuXucPH23n0x0n6NuiBi93izfF8+DcruvP3dlamjubd5X0rdm/Dg5be8Zu3Wxc3OBKnbzmtiQF+d2DepyfZrbOy5KEBVp5d3AL3vxyPzO+PsS+Hy8wa0AyVUJlv5WzmL2LzlZSMkNsRvO3WninbxJxEUFM+89Bss7kMb1/c8ICzNkRKInKTo4OyurhAcXOlICbkog9myBtDa6S6lTeUCC+lcVH8XynhsRHh/GHj7bz4NS1zOzfnBQveLCcEcx8yKzZk6gZ+Pgo/tCxATUjAnlh2U56zlzPvMEtqFEp0OjQbmP8XM9NONraWtK6961JpFtSNOvGtefI5AdYN6693QOppCKwtxSIi/P7hGqkj0wlyM9CvzkbWbTxqFusw7sbM3fReWMreln1SqnBgiEt+en8JbrPWGfKQ6DLlaiUUr2UUruVUjeUUqZZr3cGRwdlt6ToUhshymtsxwZYfW4/FsVqUV5VIC5O/SohrBjZjtS6kfwpfRfjPtnJpRKeNyYcZ+YuOjMnUTNqWzeSZSPym5L6zN7A5zvN1RFY3hnVLqAH8I0OsZhaWQblhM5NSu0yKs8DIrslRTOlVyLhRdaVKwZa6dOiBlNW77/pmq5+EKUZhAVamftIC0a3r8uSjGP0SdvIiXPyi0ovZu6iM3MSNau6lUNYPiKVxtVDGfH+Fmb/1zwdgbp0/Smlvgb+oGmaXW1H7tihVNwZfcU1PTjymmW9pqNxWn0UqPwmi+Jexxs6ob7YdYLnlm4nwM/CjP7JtKxVvrqVdP3lM+t7xxljy1tcunqd5z7azmc7TtCvZSyTujZxWUeg13f9lbe4WtaTn0vqMnJGMbq4axbXGVh0vd4bis6d4qtRJyqYoQszeXjORl7s3JiBrWu69YnSZmDWLjo5qb3s/K0WpvZNomalQGZ8fYisM7lM79+cUH/jOgJLnVEppb4CqhbzqfGapq0o+JqvKWVGpZQaCgwFiI2NTT569GhZYy4TW/uVosMDWDeuvVNes7S7zVrjPqO4n74Cjkx+oEyvaeuaxVHY7k505s/FSOfyrvLskm38e99JeiXH8HK3ePzLcFKIzKiEN1i6+Rh/XL6T2lFBzBvcgpiKzu0ILPOMStO0e/UIQNO0NCAN8geVHtd0hLOKq+lbs3lp1W7O5Oafeh4eYGVilyZA6TMVW0miPOvoJbXFF/e13lZ0DguwMmdQCm/9+wDv/PsA+3/K328ltQshbte7RQ2iKwbw5KJMuk1fz9xHUkisEe7yOLymPd0ZxdX0rdmM/Xj7r0kK8jf0Pr1kG08v2VZqe2xxxWhFflIra8NDcde0+iislpuXuAqL3t5YdPbxyT9RevbAZA6fukiXaWvZdDjH6LCEMKXUupEsG94Wf6sPfdI28MWuH10eQ3nb07srpbKANsBnSqnV+oSlP2d0KE1Zvf+mBgV7FJ2pdEuKvuncPwW/LtuV9XEeRa+pyF/Cm9IrkSk9E2/6WGFR2cydW87WsUlV0ke2JTTASv9/bOK99d8bHZIQplSvSn5HYMOqoQxfnMk//nfYpR2BXnXWn94dSo7UgwrZqv0YUUMrZNbOLVc5f+kqzy7ZTnLNigy/u45d3+OsGpVSqhcwEWgEtDR7J63wLpeuXufZpdv4584f6d8qlpe6NMFXx45Ar+/6A/07lBypB0HJMxUja0Vm7dxylVB/K2kDkzFJA2Dh3sTZRgcixK38rRam9WvOXyrtZ9Z/D5F1Jo9pDycR4uSOQK+pUTnD2I4Nbqv92FJ0ua043lgrMhMfH2WKVnVN0/Zqmibn/AjT8il4FtzrPRJYe/Bnes3a4PQbaklU5dAtKZopPRNtHpUE+bOot/o0K/UMP1u1onsaRnndiRJCCPPr1zKW+Y+2IPtMHt2mr2Nn1jmnvZZXLf05w63LZmWt9xS3QfGehlF8kpnt8ZtxvY09exPtvE7RvYk6RSeE/e6oF8XHw9syZP5mes/ewDv9krivcRXdX8ermincjZENFqJkzt7w6y7HkgkBcPLCJZ54L4Md2ef40wONGZIaV6aldFvjSpb+TMzbNuMKIdxT5RB/PhzahvsbV+HlT/cwYeVurl2/odv1JVGZmDRYeB932psoRFEBfhZm9k9m6J21WbDhKE8syOCXy9d0ubYkKhPz5s243krTtOWapsVomlZB07QqmqZ1NDomIezl46P44+8b8Uq3eL458DPDF2Xqcl1ppjAxOQFaCOGOBrSuSY1KgYT665NiJFGZnLdvxhWez9tPRvFUd9WP0u1akqiEEIYp73PihHeQGpUQwjAlPTxUiEIyo9KZLGMIYT/ZgiHsIYlKR7KMIYRjnPHwUD3Jjac5yNKfjmQZQwjHmHkLRuGNZ/bZPDTK/ow4UX4yo9KRLGMI4RhXbsFwdHZU0o2nzKpcSxKVjsy+jCGEGbliC0ZZluXlxtM8ZOlPR2ZexhDCm5VlWV6OMDMPSVQ66pYUzes9EogOD0BR+sMShRCuUZbZkbNvPNO3Zsuz5uwkS386c9YyhnQfCVF2ZVmWd2b9TDqEHSOJyg3Im1qI8hnbscFNYwjsmx0568ZTGjUcU66lP6XUFKXUPqXUDqXUcqVUuF6Bid9I27sQ5WO2ZXlp1HBMeWdU/wJe0DTtmlLqDeAF4P/KH5YoSt7UQpSfmQ54lg5hx5RrRqVp2peaphU+GWsjEFP+kMzP1UVQ6T4SwlzK+ztAOoQdo2fX3xDgcx2vZ0pG7FaXN7UQ5qHH7wCzLUWaXalLf0qpr4CqxXxqvKZpKwq+ZjxwDVhcwnWGAkMBYmNjyxSsGRhRBJUHKAphHnr9DjDTUqTZlZqoNE27t6TPK6UeAR4EOmiappVwnTQgDSAlJcXm15mdUfUieVMLYQ5SM3a9cjVTKKU6kd88cZemabn6hGRuUgQVonzcfU+g/A5wvfLWqKYBIcC/lFLblFKzdIjJtNK3ZnPx8rXbPi71IiHs4wknkruyZiynV+Qr14xK07S6egVidrduui1UMdDKhM5N3OqOUAijeMJGV1fVjGWj/2/kZAo7FTfAAAL9fL3uTSNEWXlKfccVNWNPSOp6kUNp7eQpA0wII8meQPvJ75zfSKKykwww4QqefiyZ7Am0n/zO+Y0kKjvJABMu8i8gXtO0psB35B9L5jFko6v95HfOb6RGZSfZdCtcQdO0L4v8dSPQ06hYnEX2BNpHfuf8RhKVA2SACRcbAiyx9UlPOe1F2Ca/c/JJohLCxfQ6lsxTTnsRojSSqIRwMb2OJRPCW0iiEsJEvPFYMiFKI11/QpiLVx1LJoQ9ZEYlhIl407FkQthLGbEErpQ6BRx10uUjgZ+ddO2ykHhsM1Ms4Fg8NTVNi3JmMI6QMWUoM8VjpljA8XiKHVeGJCpnUkplaJqWYnQchSQe28wUC5gvHrMw289F4rHNTLGAfvFIjUoIIYSpSaISQghhap6YqNKMDuAWEo9tZooFzBePWZjt5yLx2GamWECneDyuRiWEEMKzeOKMSgghhAfxyERltmf6KKV6KaV2K6VuKKUM6chRSnVSSu1XSh1USo0zIoYiscxTSp1USu0yMo6CWGoopf6jlNpb8P/oKaNjMiMZU8XGIGOq+Fh0H1Memagw3zN9dgE9gG+MeHGllAWYDvwOaAz0U0o1NiKWAvOBTga+flHXgOc0TWsEtAZGGvyzMSsZU0XImCqR7mPKIxOVpmlfapp2reCvG4EYg+PZq2nafgNDaAkc1DTtsKZpV4APga5GBaNp2jfAaaNevyhN005omral4L8vAHsBea7CLWRM3UbGlA3OGFMemahuMQT43OggDBYNHCvy9yzkl/FtlFJxQBKwydhITE/GlIwpu+g1ptz2rD+9nunjyngMpIr5mLR7FqGUCgY+AZ7WNO280fEYQcaUQ2RMlULPMeW2icpsz/QpLR6DZQE1ivw9BjhuUCymo5Sykj+gFmuatszoeIwiY8ohMqZKoPeY8silvyLP9Okiz/QBYDNQTylVSynlB/QFVhockykopRQwF9iradrfjI7HrGRM3UbGlA3OGFMemagw2TN9lFLdlVJZQBvgM6XUale+fkERfBSwmvzC5lJN03a7MoailFIfABuABkqpLKXUY0bFAqQCA4H2Be+VbUqp3xsYj1nJmCpCxlSJdB9TcjKFEEIIU/PUGZUQQggPIYlKCCGEqUmiEkIIYWqSqIQQQpiaJCohhBCmJolKCCGEqUmiEkIIYWqSqIQQQpja/wO+r9o10Eg/NgAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "##########################\n", - "### 2D Decision Boundary\n", - "##########################\n", - "\n", - "w, b = ppn.weights, ppn.bias\n", - "\n", - "x_min = -2\n", - "y_min = ( (-(w[0] * x_min) - b[0]) \n", - " / w[1] )\n", - "\n", - "x_max = 2\n", - "y_max = ( (-(w[0] * x_max) - b[0]) \n", - " / w[1] )\n", - "\n", - "\n", - "fig, ax = plt.subplots(1, 2, sharex=True, figsize=(7, 3))\n", - "\n", - "ax[0].plot([x_min, x_max], [y_min, y_max])\n", - "ax[1].plot([x_min, x_max], [y_min, y_max])\n", - "\n", - "ax[0].scatter(X_train[y_train==0, 0], X_train[y_train==0, 1], label='class 0', marker='o')\n", - "ax[0].scatter(X_train[y_train==1, 0], X_train[y_train==1, 1], label='class 1', marker='s')\n", - "\n", - "ax[1].scatter(X_test[y_test==0, 0], X_test[y_test==0, 1], label='class 0', marker='o')\n", - "ax[1].scatter(X_test[y_test==1, 0], X_test[y_test==1, 1], label='class 1', marker='s')\n", - "\n", - "ax[1].legend(loc='upper left')\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.1" - }, - "toc": { - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 4 -}