In [8]:
import jax.numpy as jnp
from jax import grad
#from jax import jit 
import numpy as np

In [9]:

# define soft argmax
def soft_argmax(x):
 y = jnp.exp(x)
 return y/jnp.sum(y)

# Specify the how to caculate the 

def predict(params, inputs):
 W = params[0]
 b = params[1]
 scores = jnp.dot( W, inputs) + b
 outputs = soft_argmax(scores)
 return outputs

def loss_fun(params, inputs, targets):
 preds = predict(params, inputs)
 # You can use logistic loss instead
 return jnp.sum((preds - targets)**2) 

grad_fun = grad(loss_fun) # gradient evaluation function

# You can even use Just In Time (JIT) compilation to optimize the implementation if you want
# grad_fun = jit(grad(loss_fun))
 

In [10]:
# Lets' cook up some input for this 
W = np.random.randn(3,4)
b = np.random.randn(3,1)
inputs = np.random.randn(4,1)
targets = np.array([1,0,0])
params=[W,b]

In [11]:
params

[array([[ 0.77956078, 3.37463341, 0.40595436, 0.20524409],
 [-0.32273478, 0.23802133, -0.32866143, -0.03417101],
 [-0.27570044, 0.6516192 , -0.35532861, -0.87215494]]),
 array([[1.49792127],
 [0.95496653],
 [0.18057272]])]

In [12]:
# Generate a prediction for this inputs
predict(params, inputs)

DeviceArray([[0.9819335 ],
 [0.01561047],
 [0.00245601]], dtype=float32)

In [13]:
# Compute the gradient from this imput
grad_fun(params, inputs,targets)

[DeviceArray([[ 0.13573891, 0.02055991, 0.1661767 , 0.14467116],
 [-0.1170654 , -0.01773149, -0.14331588, -0.12476883],
 [-0.01867339, -0.00282839, -0.02286067, -0.01990218]], dtype=float32),
 DeviceArray([[ 0.10304634],
 [-0.08887032],
 [-0.01417592]], dtype=float32)]

In [25]:

# What if we want to do a neural network?

def predict_nn(params, inputs):
 
 W1 = params[0]
 b1 = params[1]
 W2 = params[2]
 b2 = params[3]
 hidden = jnp.dot( W1, inputs) + b1
 activated = jnp.tanh(hidden)
 scores = jnp.dot(W2, activated) +b2
 
 outputs = soft_argmax(scores)
 return outputs

def loss_fun_nn(params, inputs, targets):
 preds = predict_nn(params, inputs)
 # You can use logistic loss instead
 return jnp.sum((preds - targets)**2) 

grad_fun_nn = grad(loss_fun_nn) # gradient evaluation function



In [26]:
# Lets' cook up some input for this 
W1 = np.random.randn(6,4)
b1 = np.random.randn(6,1)
W2 = np.random.randn(3,6)
b2 = np.random.randn(3,1)
params=[W1,b1,W2,b2]

inputs = np.random.randn(4,1)
targets = np.array([1,0,0])

In [27]:
# Generate a prediction for this inputs
predict_nn(params, inputs)

DeviceArray([[0.10044566],
 [0.77999187],
 [0.11956251]], dtype=float32)

In [28]:
# Compute the gradient from this imput
grad_fun_nn(params, inputs,targets)

[DeviceArray([[ 2.6204654e-05, 4.7051930e-04, 3.3387411e-04,
 3.4875007e-05],
 [-2.9092113e-04, -5.2236523e-03, -3.7066326e-03,
 -3.8717841e-04],
 [ 1.9083241e-02, 3.4265032e-01, 2.4314001e-01,
 2.5397327e-02],
 [ 2.7686538e-04, 4.9712737e-03, 3.5275482e-03,
 3.6847201e-04],
 [ 5.0642653e-03, 9.0931728e-02, 6.4523920e-02,
 6.7398823e-03],
 [-8.7190457e-03, -1.5655535e-01, -1.1108955e-01,
 -1.1603922e-02]], dtype=float32),
 DeviceArray([[-0.00071549],
 [ 0.00794324],
 [-0.5210443 ],
 [-0.00755947],
 [-0.1382735 ],
 [ 0.23806275]], dtype=float32),
 DeviceArray([[-0.32022256, -0.3198762 , 0.2565086 , -0.31863633,
 0.19721033, 0.22003594],
 [ 0.68770146, 0.6869576 , -0.550871 , 0.68429494,
 -0.42352366, -0.47254333],
 [-0.36747894, -0.36708143, 0.2943625 , -0.3656586 ,
 0.22631335, 0.25250742]], dtype=float32),
 DeviceArray([[-0.32081914],
 [ 0.68898267],
 [-0.36816356]], dtype=float32)]