{ "cells": [ { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import jax.numpy as jnp\n", "from jax import grad\n", "#from jax import jit \n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "\n", "# define soft argmax\n", "def soft_argmax(x):\n", " y = jnp.exp(x)\n", " return y/jnp.sum(y)\n", "\n", "# Specify the how to caculate the \n", "\n", "def predict(params, inputs):\n", " W = params[0]\n", " b = params[1]\n", " scores = jnp.dot( W, inputs) + b\n", " outputs = soft_argmax(scores)\n", " return outputs\n", "\n", "def loss_fun(params, inputs, targets):\n", " preds = predict(params, inputs)\n", " # You can use logistic loss instead\n", " return jnp.sum((preds - targets)**2) \n", "\n", "grad_fun = grad(loss_fun) # gradient evaluation function\n", "\n", "# You can even use Just In Time (JIT) compilation to optimize the implementation if you want\n", "# grad_fun = jit(grad(loss_fun))\n", " " ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Lets' cook up some input for this \n", "W = np.random.randn(3,4)\n", "b = np.random.randn(3,1)\n", "inputs = np.random.randn(4,1)\n", "targets = np.array([1,0,0])\n", "params=[W,b]" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[array([[ 0.77956078, 3.37463341, 0.40595436, 0.20524409],\n", " [-0.32273478, 0.23802133, -0.32866143, -0.03417101],\n", " [-0.27570044, 0.6516192 , -0.35532861, -0.87215494]]),\n", " array([[1.49792127],\n", " [0.95496653],\n", " [0.18057272]])]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[0.9819335 ],\n", " [0.01561047],\n", " [0.00245601]], dtype=float32)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Generate a prediction for this inputs\n", "predict(params, inputs)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[DeviceArray([[ 0.13573891, 0.02055991, 0.1661767 , 0.14467116],\n", " [-0.1170654 , -0.01773149, -0.14331588, -0.12476883],\n", " [-0.01867339, -0.00282839, -0.02286067, -0.01990218]], dtype=float32),\n", " DeviceArray([[ 0.10304634],\n", " [-0.08887032],\n", " [-0.01417592]], dtype=float32)]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Compute the gradient from this imput\n", "grad_fun(params, inputs,targets)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "\n", "# What if we want to do a neural network?\n", "\n", "def predict_nn(params, inputs):\n", " \n", " W1 = params[0]\n", " b1 = params[1]\n", " W2 = params[2]\n", " b2 = params[3]\n", " hidden = jnp.dot( W1, inputs) + b1\n", " activated = jnp.tanh(hidden)\n", " scores = jnp.dot(W2, activated) +b2\n", " \n", " outputs = soft_argmax(scores)\n", " return outputs\n", "\n", "def loss_fun_nn(params, inputs, targets):\n", " preds = predict_nn(params, inputs)\n", " # You can use logistic loss instead\n", " return jnp.sum((preds - targets)**2) \n", "\n", "grad_fun_nn = grad(loss_fun_nn) # gradient evaluation function\n", "\n" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# Lets' cook up some input for this \n", "W1 = np.random.randn(6,4)\n", "b1 = np.random.randn(6,1)\n", "W2 = np.random.randn(3,6)\n", "b2 = np.random.randn(3,1)\n", "params=[W1,b1,W2,b2]\n", "\n", "inputs = np.random.randn(4,1)\n", "targets = np.array([1,0,0])" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[0.10044566],\n", " [0.77999187],\n", " [0.11956251]], dtype=float32)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Generate a prediction for this inputs\n", "predict_nn(params, inputs)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[DeviceArray([[ 2.6204654e-05, 4.7051930e-04, 3.3387411e-04,\n", " 3.4875007e-05],\n", " [-2.9092113e-04, -5.2236523e-03, -3.7066326e-03,\n", " -3.8717841e-04],\n", " [ 1.9083241e-02, 3.4265032e-01, 2.4314001e-01,\n", " 2.5397327e-02],\n", " [ 2.7686538e-04, 4.9712737e-03, 3.5275482e-03,\n", " 3.6847201e-04],\n", " [ 5.0642653e-03, 9.0931728e-02, 6.4523920e-02,\n", " 6.7398823e-03],\n", " [-8.7190457e-03, -1.5655535e-01, -1.1108955e-01,\n", " -1.1603922e-02]], dtype=float32),\n", " DeviceArray([[-0.00071549],\n", " [ 0.00794324],\n", " [-0.5210443 ],\n", " [-0.00755947],\n", " [-0.1382735 ],\n", " [ 0.23806275]], dtype=float32),\n", " DeviceArray([[-0.32022256, -0.3198762 , 0.2565086 , -0.31863633,\n", " 0.19721033, 0.22003594],\n", " [ 0.68770146, 0.6869576 , -0.550871 , 0.68429494,\n", " -0.42352366, -0.47254333],\n", " [-0.36747894, -0.36708143, 0.2943625 , -0.3656586 ,\n", " 0.22631335, 0.25250742]], dtype=float32),\n", " DeviceArray([[-0.32081914],\n", " [ 0.68898267],\n", " [-0.36816356]], dtype=float32)]" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Compute the gradient from this imput\n", "grad_fun_nn(params, inputs,targets)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }