{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "kernelspec": { "name": "python3", "display_name": "Python 3" }, "colab": { "name": "tut10_gans.ipynb", "provenance": [], "collapsed_sections": [] }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "-Zq1dWoIup0n", "colab_type": "text" }, "source": [ "# CSC413/2516 Winter 2020\n", "\n", "# Generative Adversarial Network (GAN) Tutorial\n", "\n", "In this tutorial we'll see how to code a simple Generative Adversarial Network (GAN) that generates handwritten digits resembling those in the MNIST dataset.\n", "\n", "\n", "## Resources\n", "\n", "The following are great resources for learning more about GANs.\n", "\n", "1. [Generative Adversarial Nets (Goodfellow et al.)](https://arxiv.org/pdf/1406.2661.pdf), the paper that introduced GANs\n", "2. [An Introduction to GANs in Tensorflow](http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/)\n", "3. [Generative Models Blog Post from OpenAI](https://blog.openai.com/generative-models/)\n", "\n", "Some slides are available [here](https://docs.google.com/presentation/d/1ZNEXW-Llyxb9uAVRGlhfvJKDHTXp_pobs99SAH4ywpU/edit?usp=sharing)." ] }, { "cell_type": "code", "metadata": { "id": "NHNPF0RIup0p", "colab_type": "code", "colab": {} }, "source": [ "import itertools\n", "import math\n", "import time\n", "import matplotlib.pyplot as plt\n", "from IPython import display\n", "import numpy.random as npr\n", "\n", "# Plain NumPy would work for everything except computing gradients,\n", "# so we use jax.numpy, a NumPy-compatible API that jax.grad can differentiate\n", "#import numpy as np\n", "import jax.numpy as np\n", "from jax import grad\n", "\n", "# Alternatively, autograd offers the same drop-in interface\n", "#import autograd.numpy as np\n", "#from autograd import grad\n", "\n", "#%matplotlib inline" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "odyXoHXsup0s", "colab_type": "text" }, "source": [ "## Load Dataset" ] }, { "cell_type": "code", "metadata": { "id": 
"9RLNGvRdup0t", "colab_type": "code", "outputId": "564f0024-0b7d-4d65-b588-fe2bdf8f5992", "colab": { "base_uri": "https://localhost:8080/", "height": 525 } }, "source": [ "def load_mnist():\n", "    # tensorflow.examples.tutorials.mnist was removed in TensorFlow 2.x,\n", "    # so load the raw arrays through tf.keras instead.\n", "    from tensorflow.keras.datasets import mnist\n", "\n", "    (X_train, y_train), (X_test, y_test) = mnist.load_data()\n", "\n", "    # Flatten each 28x28 uint8 image into a 784-vector of float32 pixels in [0, 1],\n", "    # the format the old input_data loader produced. Unlike input_data, this keeps\n", "    # all 60,000 training images (input_data held out 5,000 for validation).\n", "    X_train = np.asarray(X_train.reshape(-1, 28 * 28), dtype=np.float32) / 255.0\n", "    X_test = np.asarray(X_test.reshape(-1, 28 * 28), dtype=np.float32) / 255.0\n", "\n", "    # Labels are returned as column vectors of shape (N, 1)\n", "    return X_train, y_train.reshape(-1, 1), X_test, y_test.reshape(-1, 1)\n", "\n", "train_images, train_labels, test_images, test_labels = load_mnist()" ], "execution_count": 3, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n",
"The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x.\n",
"We recommend you upgrade now \n",
"or ensure your notebook will continue to use TensorFlow 1.x via the %tensorflow_version 1.x magic:\n",
"more info.