{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "DVimr8MhxmRJ" }, "outputs": [], "source": [ "from collections import Counter\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F" ] }, { "cell_type": "markdown", "metadata": { "id": "T1pCJFXJxmRJ" }, "source": [ "### I. Representing Text\n", "In NLP, we deal with words and phrases, which are discrete symbols. How do we represent them in a form that a neural network can easily process?" ] }, { "cell_type": "markdown", "metadata": { "id": "15rh0T_7xmRK" }, "source": [ "#### 1. Bag of Words\n", "Bag-of-words (BoW) is a conventional way of representing documents that predates deep learning. The idea is to represent each document by the count (or frequency) of each vocabulary word in it.\n", "\n", "Search for TF-IDF and n-gram models if you want to know more." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "doQmD5txxmRL", "outputId": "dcbc1a41-1321-4521-836a-dd5dd09f7314" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog', '.']\n" ] } ], "source": [ "sentence = \"the quick brown fox jumps over the lazy dog .\"\n", "# here we assume tokens are separated by spaces\n", "# real applications usually rely on more sophisticated tokenizers to split the raw text\n", "tokens = sentence.split()\n", "print(tokens)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MQTm7zdexmRL", "outputId": "85c83a6d-283b-4adf-b655-293ccd0f48ed" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "{'fox': 0, 'quick': 1, 'the': 2, 'lazy': 3, '.': 4, 'dog': 5, 'jumps': 6, 'over': 7, 'brown': 8}\n" ] } ], "source": [ "# set() has no guaranteed order, so the indices (and the outputs below) may differ between runs\n", "idx2token = list(set(tokens))\n", "token2idx = dict((t, i) for (i, t) in enumerate(idx2token))\n", "vocab_size 
= len(idx2token)\n", "print(token2idx)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "l-Hi8KvuxmRL", "outputId": "eea9717e-07e4-449a-fc37-8f8c923c9be7" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor([1, 1, 2, 1, 1, 1, 1, 1, 1])\n" ] } ], "source": [ "counter = Counter(tokens)\n", "# entry i is the number of times idx2token[i] occurs in the sentence\n", "bow = torch.tensor([counter[token] for token in idx2token])\n", "print(bow)" ] }, { "cell_type": "markdown", "metadata": { "id": "WDHZKwvsxmRL" }, "source": [ "The drawbacks of BoW include:\n", "- the vector has one entry per vocabulary word, so it becomes huge (and mostly zeros) with a large vocabulary\n", "- each document is reduced to a single count vector, which discards word order and gives no per-word representation\n", "- words are treated as independent, so related words (e.g. synonyms) share nothing in the representation" ] }, { "cell_type": "markdown", "metadata": { "id": "T7xDNWGixmRL" }, "source": [ "#### 2. One Hot Encoding\n", "Another approach closely related to BoW is one-hot encoding. It represents each word as an indicator vector: a vector of length `vocab_size` that is 1 at the word's index and 0 elsewhere." 
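, "\n", "\n", "Summing the one-hot vectors of all tokens in a document gives back its BoW count vector, which shows how closely the two representations are related. The next cell is a minimal, self-contained sketch of this on a made-up sentence; the `demo_*` names are hypothetical and chosen so that the variables used by the other cells are not overwritten." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from collections import Counter\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "# a made-up sentence, independent of the example used elsewhere in this notebook\n", "demo_tokens = 'the cat sat on the mat'.split()\n", "demo_vocab = sorted(set(demo_tokens))  # sorted for a deterministic index order\n", "demo_token2idx = {t: i for i, t in enumerate(demo_vocab)}\n", "demo_ids = torch.tensor([demo_token2idx[t] for t in demo_tokens])\n", "\n", "# (seq_length, vocab_size): one indicator vector per token\n", "demo_one_hot = F.one_hot(demo_ids, num_classes=len(demo_vocab))\n", "# summing over the sequence dimension counts how often each word occurs\n", "demo_bow = demo_one_hot.sum(dim=0)\n", "\n", "# compare with a direct count in the same vocabulary order\n", "demo_counter = Counter(demo_tokens)\n", "print(demo_vocab)\n", "print(demo_bow)\n", "print(torch.equal(demo_bow, torch.tensor([demo_counter[t] for t in demo_vocab])))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The sum keeps the counts but, like BoW, loses the word order. The cell below applies one-hot encoding to the example sentence from above." 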
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "hxCg8yHgxmRL", "outputId": "98cb900d-3a84-445a-a666-9ff4b56b8d5e" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor([[0, 0, 1, 0, 0, 0, 0, 0, 0],\n", " [0, 1, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 1],\n", " [1, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 1, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 1, 0],\n", " [0, 0, 1, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 1, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 1, 0, 0, 0],\n", " [0, 0, 0, 0, 1, 0, 0, 0, 0]])\n" ] } ], "source": [ "token_ids = torch.tensor([token2idx[token] for token in tokens])\n", "# (seq_length, vocab_size): one indicator vector per token\n", "print(F.one_hot(token_ids, num_classes=vocab_size))" ] }, { "cell_type": "markdown", "metadata": { "id": "JeqokaKwxmRL" }, "source": [ "Similar to BoW, the drawbacks are:\n", "- with a large vocabulary, the representation is huge (and sparse).\n", "- it treats all words as independent: any two different one-hot vectors are orthogonal, so no relation between words is captured." ] }, { "cell_type": "markdown", "metadata": { "id": "mTF6H7MjxmRM" }, "source": [ "#### 3. Dense Word Embeddings\n", "Word embeddings keep a lookup table that maps each word in the vocabulary to a dense vector of learnable parameters. Each sentence is then represented as a sequence of word embeddings.\n", "\n", "The embedding parameters are usually trained together with all other model parameters. Alternatively (and less commonly), they can be fixed to pretrained values (e.g. word2vec)." 
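, "\n", "\n", "As a minimal sketch of the pretrained option, `nn.Embedding.from_pretrained` builds an embedding layer from an existing weight matrix, and `freeze=True` keeps that matrix fixed during training. The 'pretrained' vectors in the next cell are random placeholders standing in for real ones (e.g. loaded from word2vec), and the `pretrained_*` / `frozen_*` / `via_one_hot` names are hypothetical. The cell also checks that an embedding lookup gives the same result as multiplying one-hot vectors by the weight matrix, which connects embeddings back to one-hot encoding." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "# placeholder 'pretrained' vectors: 5 words, 3 dimensions (random, for illustration only)\n", "pretrained_vectors = torch.randn(5, 3)\n", "\n", "# freeze=True sets requires_grad=False, so the optimizer will not update these weights\n", "frozen_embeds = nn.Embedding.from_pretrained(pretrained_vectors, freeze=True)\n", "pretrained_ids = torch.tensor([0, 2, 2, 4])\n", "looked_up = frozen_embeds(pretrained_ids)  # (4, 3)\n", "\n", "# a lookup selects rows of the weight matrix, just like a one-hot matrix product\n", "via_one_hot = F.one_hot(pretrained_ids, num_classes=5).float() @ frozen_embeds.weight\n", "print(frozen_embeds.weight.requires_grad)\n", "print(torch.allclose(looked_up, via_one_hot))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The two checks print `False` and `True`. In the rest of this notebook, the embeddings are instead randomly initialized and trainable." 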
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "plicRtrexmRM", "outputId": "fc8f92a5-df8d-4f14-d57b-f69236cc4509" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([10, 8])\n", "tensor([[ 1.6840, -0.0304, -0.3668, 0.6942, -0.2271, 0.9454, 0.3811, -1.2579],\n", " [-0.7907, -0.6264, 0.4665, 1.9389, 0.8427, 0.8852, -1.2025, -0.0106],\n", " [ 1.9983, 0.2476, 0.7603, -1.0315, 0.4166, -0.5394, 1.6762, -1.2370],\n", " [-0.4871, 1.2124, 0.5749, 0.5276, 0.9060, 0.2522, -0.1403, 0.4153],\n", " [ 0.6686, -1.2796, 0.2280, 1.4104, 0.7402, -0.2454, 0.5503, 0.0655],\n", " [-0.4550, 0.1212, -0.0990, 0.9856, -0.4488, 0.0389, 0.2322, 0.0431],\n", " [ 1.6840, -0.0304, -0.3668, 0.6942, -0.2271, 0.9454, 0.3811, -1.2579],\n", " [ 0.9182, 0.9667, 0.3924, 0.6065, 1.6190, 0.2829, -2.7679, -0.8704],\n", " [ 0.1241, 0.2837, -0.9605, -1.1846, 0.1741, 0.9107, -1.0985, 2.3858],\n", " [ 1.6951, 0.6644, -0.3910, -0.7861, 0.3688, 0.4912, 0.2069, -0.2365]],\n", " grad_fn=)\n" ] } ], "source": [ "embed_size = 8\n", "embeds = nn.Embedding(vocab_size, embed_size)\n", "embeddings = embeds(token_ids)\n", "# (seq_length, embed_size); identical tokens (e.g. both occurrences of 'the') get identical vectors\n", "print(embeddings.shape)\n", "print(embeddings)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZBxR4hlDxmRM" }, "source": [ "#### 4. Representing Batches of Sentences\n", "Sentences do not necessarily have the same number of words. How can we represent them as a batch?" 
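, "\n", "\n", "The usual answer is padding: a special padding token is appended to the shorter sentences until every sentence in the batch has the same length. The cells below do this by hand. PyTorch also provides `torch.nn.utils.rnn.pad_sequence`, which pads a list of 1-D tensors to the length of the longest one. The next cell is a minimal sketch with made-up token ids, assuming index 0 is reserved for padding; the `demo_*` names are hypothetical." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.nn.utils.rnn import pad_sequence\n", "\n", "# two made-up sentences that are already converted to token ids\n", "demo_seqs = [torch.tensor([4, 3, 10, 2]), torch.tensor([4, 7])]\n", "\n", "# pad every sequence to the longest one; batch_first=True gives (batch_size, max_length)\n", "demo_batch = pad_sequence(demo_seqs, batch_first=True, padding_value=0)\n", "print(demo_batch.shape)\n", "print(demo_batch)\n", "\n", "# True marks padded positions (valid only because 0 is never a real token id here)\n", "print(demo_batch == 0)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "`pad_sequence` pads to the longest sequence in the list it receives, i.e. to the maximum length within the batch, rather than to a fixed length." 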
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dne5cuRIxmRM", "outputId": "60688663-f599-4c89-aac4-d2c0ecafd580" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "{'<pad>': 0, '<unk>': 1, 'fox': 2, 'quick': 3, 'the': 4, 'lazy': 5, '.': 6, 'dog': 7, 'jumps': 8, 'over': 9, 'brown': 10}\n" ] } ], "source": [ "# reserve index 0 for padding (<pad>) and index 1 for unknown words (<unk>)\n", "idx2token = ['<pad>', '<unk>'] + idx2token\n", "token2idx = dict((t, i) for (i, t) in enumerate(idx2token))\n", "print(token2idx)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "nhND9YkYxmRM", "outputId": "ea8b6645-8d1e-44e4-f9ca-8864ab14e32c" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog', '.'], ['the', 'dog', 'is', 'lying', 'beside', 'the', 'fox']]\n" ] } ], "source": [ "sentences = [\n", " \"the quick brown fox jumps over the lazy dog .\",\n", " \"the dog is lying beside the fox\",\n", " ]\n", "\n", "# Each word in a sentence has a part-of-speech (POS) category.\n", "# For the meaning of common POS labels, see https://www.sketchengine.eu/penn-treebank-tagset/\n", "# (the tags used here are a simplified version of that tagset)\n", "pos_tags = [\n", " \"DT ADJ ADJ NN VB IN DT ADJ NN .\",\n", " \"DT NN VB VB IN DT NN\"\n", "]\n", "batch_tokens = [sent.split() for sent in sentences]\n", "batch_labels = [tags.split() for tags in pos_tags]\n", "print(batch_tokens)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ba24efJoxmRM" }, "outputs": [], "source": [ "token_ids_0 = [token2idx[token] for token in batch_tokens[0]]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 258 }, "id": "ztfdyZFdxmRM", "outputId": "ff6e5ea1-d374-4019-ba0c-f216f5c80b27" }, "outputs": [ { "output_type": "error", "ename": "KeyError", "evalue": "ignored", "traceback": [ 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# there might be words not present in the vocabulary\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtoken_ids_1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtoken2idx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtoken\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mtoken\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatch_tokens\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# there might be words not present in the vocabulary\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtoken_ids_1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtoken2idx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtoken\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mtoken\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatch_tokens\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mKeyError\u001b[0m: 'is'" ] } ], "source": [ "# there might be words not present in the vocabulary\n", "token_ids_1 = [token2idx[token] for token in batch_tokens[1]]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "q-_RvZkXxmRN", "outputId": "29c248bd-70aa-481c-b26a-bdbbce739415" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ 
"[4, 7, 1, 1, 1, 4, 2]\n" ] } ], "source": [ "# map words that are not in the vocabulary to the <unk> token\n", "token_ids_1 = [token2idx.get(token, token2idx['<unk>']) for token in batch_tokens[1]]\n", "print(token_ids_1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9_fCW8KIxmRN", "outputId": "2ae7fa9e-0657-43fa-c360-50e671c87c2b" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor([[ 4, 3, 10, 2, 8, 9, 4, 5, 7, 6],\n", " [ 4, 7, 1, 1, 1, 4, 2, 0, 0, 0]])\n" ] } ], "source": [ "# pad every sequence with <pad> up to the length of the longest one\n", "max_length = max([len(ids) for ids in [token_ids_0, token_ids_1]])\n", "batch_ids = torch.tensor(\n", " [\n", " token_ids_0 + [token2idx['<pad>']] * (max_length - len(token_ids_0)),\n", " token_ids_1 + [token2idx['<pad>']] * (max_length - len(token_ids_1)),\n", " ]\n", ")\n", "print(batch_ids)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NXJsjYSVxmRN", "outputId": "dbbee447-fc73-4c75-b9c1-dc0d3beec513" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([2, 10, 8])\n", "tensor([[[ 7.8184e-01, -1.3621e+00, 4.2569e-01, -1.8461e+00, 2.1676e+00,\n", " -9.1152e-01, -9.6761e-01, 2.3834e-01],\n", " [-6.3344e-01, -4.9714e-01, -4.8708e-01, 6.8581e-01, -1.7923e-01,\n", " 8.5219e-02, 8.8637e-01, -4.0823e-01],\n", " [-9.2896e-01, -2.6551e-01, -1.7135e+00, -3.6573e-01, -1.0628e+00,\n", " -1.8727e-01, 1.6141e+00, 1.0059e+00],\n", " [ 9.2869e-01, -1.9563e+00, 5.5091e-01, -3.0735e-01, -1.6190e+00,\n", " 1.5873e-01, 1.1222e+00, -1.1373e+00],\n", " [ 2.2582e+00, -2.5128e+00, 6.3548e-01, 1.7982e-01, -3.6781e-01,\n", " -6.6681e-01, 9.5394e-01, 1.4822e-01],\n", " [ 6.1321e-01, -1.0435e-01, -2.0886e+00, -1.2996e-01, -4.2351e-02,\n", " 8.4045e-01, -1.5285e+00, 8.9550e-02],\n", " [ 7.8184e-01, -1.3621e+00, 4.2569e-01, -1.8461e+00, 2.1676e+00,\n", " -9.1152e-01, -9.6761e-01, 2.3834e-01],\n", 
" [ 1.4870e-01, 1.0553e+00, -1.8828e+00, -1.0740e+00, -4.0737e-01,\n", " -7.8651e-01, 1.4763e+00, 4.4058e-01],\n", " [-1.0930e+00, -9.9038e-01, -9.3767e-01, -6.1010e-01, -4.0999e-01,\n", " 8.9367e-02, 1.3753e+00, 9.0367e-01],\n", " [ 1.5305e+00, 4.4224e-01, 5.6699e-01, 2.0215e+00, -1.2492e+00,\n", " 4.2191e-01, 8.5750e-01, -9.0576e-02]],\n", "\n", " [[ 7.8184e-01, -1.3621e+00, 4.2569e-01, -1.8461e+00, 2.1676e+00,\n", " -9.1152e-01, -9.6761e-01, 2.3834e-01],\n", " [-1.0930e+00, -9.9038e-01, -9.3767e-01, -6.1010e-01, -4.0999e-01,\n", " 8.9367e-02, 1.3753e+00, 9.0367e-01],\n", " [ 2.3303e+00, 4.0642e-01, -5.2861e-01, -8.1552e-01, -9.0573e-01,\n", " 1.1469e-03, 1.3433e+00, -5.4895e-01],\n", " [ 2.3303e+00, 4.0642e-01, -5.2861e-01, -8.1552e-01, -9.0573e-01,\n", " 1.1469e-03, 1.3433e+00, -5.4895e-01],\n", " [ 2.3303e+00, 4.0642e-01, -5.2861e-01, -8.1552e-01, -9.0573e-01,\n", " 1.1469e-03, 1.3433e+00, -5.4895e-01],\n", " [ 7.8184e-01, -1.3621e+00, 4.2569e-01, -1.8461e+00, 2.1676e+00,\n", " -9.1152e-01, -9.6761e-01, 2.3834e-01],\n", " [ 9.2869e-01, -1.9563e+00, 5.5091e-01, -3.0735e-01, -1.6190e+00,\n", " 1.5873e-01, 1.1222e+00, -1.1373e+00],\n", " [-3.3285e-02, -1.1480e-01, -2.4501e-01, 6.4260e-01, 2.3863e-01,\n", " -5.5390e-01, 7.9101e-01, 4.2950e-01],\n", " [-3.3285e-02, -1.1480e-01, -2.4501e-01, 6.4260e-01, 2.3863e-01,\n", " -5.5390e-01, 7.9101e-01, 4.2950e-01],\n", " [-3.3285e-02, -1.1480e-01, -2.4501e-01, 6.4260e-01, 2.3863e-01,\n", " -5.5390e-01, 7.9101e-01, 4.2950e-01]]],\n", " grad_fn=)\n" ] } ], "source": [ "embed_size = 8\n", "embeds = nn.Embedding(len(idx2token), embed_size)\n", "batch_embeds = embeds(batch_ids)\n", "# (batch_size, seq_length, embed_size)\n", "print(batch_embeds.shape)\n", "print(batch_embeds)" ] }, { "cell_type": "markdown", "metadata": { "id": "wTId660txmRN" }, "source": [ "In practice, there are two padding strategies:\n", "- pad or truncate all sentences to the same fixed length\n", "- dynamically pad to the maximum length within each batch\n", "\n", "The first strategy is easier to implement and is suitable when sentence lengths vary little across the dataset.\n", "\n", "The second is more efficient and can speed up training when sentence lengths vary a lot. It can be optimized further by building each batch from sentences of similar lengths." ] }, 
{ "cell_type": "markdown", "metadata": { "id": "Hv1Uz-rGxmRN" }, "source": [ "### II. LSTM / Transformer\n", "Given the embeddings, we can feed them into an LSTM or a Transformer model." ] }, { "cell_type": "markdown", "source": [ "For an LSTM created with `batch_first=True` (as below), the input is a tensor of shape (batch_size, seq_length, input_dim), optionally followed by a tuple `(h_0, c_0)` of initial hidden and cell states, which default to zeros if omitted. Without `batch_first=True`, the expected input shape is (seq_length, batch_size, input_dim).\n", "\n", "doc: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html\n" ], "metadata": { "id": "375Hxb7s0BnQ" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v2Lm8LuTxmRN" }, "outputs": [], "source": [ "hidden_size = 16\n", "lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)\n", "# hn and cn are the final hidden and cell states: (num_layers * num_directions, batch_size, hidden_size)\n", "output, (hn, cn) = lstm(batch_embeds)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mYIyTCD8xmRN", "outputId": "a0c37f22-25de-4b4c-a3fc-4d8e75c1e6e2" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([2, 10, 16])" ] }, "metadata": {}, "execution_count": 31 } ], "source": [ "# output holds the hidden state of the last LSTM layer at every time step\n", "# (batch_size, seq_length, hidden_size)\n", "output.shape" ] }, { "cell_type": "code", "source": [ "# For a bi-LSTM, we only need to set bidirectional=True.\n", "bilstm = nn.LSTM(embed_size, hidden_size, batch_first=True, bidirectional=True)\n", "bioutput, (hn, cn) = bilstm(batch_embeds)\n", "\n", "# the last dim of the output is the concatenation of the forward and backward hidden states at each time step\n", "# (batch_size, seq_length, 2 * hidden_size)\n", "bioutput.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "N92wPCn51QEq", "outputId": 
"212fba7c-ec7b-4b44-ff6b-3a3d6fa521eb" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([2, 10, 32])" ] }, "metadata": {}, "execution_count": 43 } ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fIIbnxu1xmRO", "outputId": "6213b050-f9b1-4194-b099-774d6b0d620f" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([2, 10, 4])\n" ] } ], "source": [ "# for sequence tagging, make one prediction for each token\n", "# suppose we are doing a sequence labeling task with 4 categories\n", "num_classes = 4\n", "linear = nn.Linear(hidden_size, num_classes)\n", "\n", "# (batch_size, seq_length, num_classes)\n", "logits = linear(output)\n", "print(logits.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FkC2ewJexmRO", "outputId": "aea2251b-2661-4e2a-bf6e-95ce240fbf7c" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([2, 4])\n" ] } ], "source": [ "# for sequence classification, aggregate the sequence dimension first\n", "# (this simple mean also averages over the padded positions)\n", "# (batch_size, num_classes)\n", "logits = linear(output.mean(dim=1))\n", "print(logits.shape)" ] }, { "cell_type": "markdown", "source": [ "For Transformers, PyTorch provides several modules:\n", "\n", "* nn.Transformer: encoder + decoder\n", "* nn.TransformerEncoder: a stack of N encoder layers\n", "* nn.TransformerDecoder: a stack of N decoder layers\n", "* nn.TransformerEncoderLayer: self-attention and a feedforward network\n", "* nn.TransformerDecoderLayer: self-attention, multi-head attention over the encoder output (encoder-decoder attention), and a feedforward network\n", "\n", "doc: https://pytorch.org/docs/stable/nn.html#transformer-layers" ], "metadata": { "id": "LoxmQS1A2WCs" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "U7U7kfwtxmRO" }, "outputs": [], "source": [ "num_heads = 4\n", 
"hidden_size = 16  # size of the feedforward network inside each layer (dim_feedforward)\n", "dropout = 0.1\n", "# a single encoder layer; nn.TransformerEncoder stacks num_layers copies of it\n", "encoder_layer = nn.TransformerEncoderLayer(\n", " embed_size, num_heads, hidden_size, dropout, batch_first=True)\n", "num_layers = 2\n", "transformer = nn.TransformerEncoder(encoder_layer, num_layers)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FWPoHinYxmRO", "outputId": "40764c3d-5224-482d-e040-115148271420" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([2, 10, 8])\n" ] } ], "source": [ "# eval mode disables dropout, which the output comparison below relies on\n", "transformer.eval()\n", "output = transformer(batch_embeds)\n", "print(output.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "prFIxcrLxmRO" }, "source": [ "For Transformer models, if the input includes padded tokens, we need to mask them out so that the output depends only on the real (unpadded) tokens." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qZA2lbSixmRO", "outputId": "9b53b3ac-64c7-4f07-adb6-c5306969f5dc" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor([[False, False, False, False, False, False, False, False, False, False],\n", " [False, False, False, False, False, False, False, True, True, True]])\n" ] } ], "source": [ "# True marks padded positions, which attention should ignore\n", "padding_mask = (batch_ids == token2idx['<pad>'])\n", "print(padding_mask)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MqMakzqVxmRO", "outputId": "fdba2dff-3a36-49a8-f0ff-4587a0b43c6b" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([2, 10, 8])\n" ] } ], "source": [ "output_masked = transformer(\n", " batch_embeds,\n", " src_key_padding_mask=padding_mask\n", ")\n", "print(output_masked.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fQO_KMPGxmRO", "outputId": 
"d3bb09de-6cc0-4716-8515-f0e9edb6b1c1" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor(True)\n", "tensor(False)\n" ] } ], "source": [ "is_equal = (output == output_masked)\n", "# the first example has no padding, so its output is unchanged\n", "print(is_equal[0].all())\n", "# the second example is padded: with the mask, none of its outputs match the unmasked run\n", "print(is_equal[1].any())" ] }, { "cell_type": "markdown", "metadata": { "id": "_DfIYt62xmRO" }, "source": [ "Note the difference between `mask` and `src_key_padding_mask` when using `nn.TransformerEncoder`.\n", "\n", "`src_key_padding_mask` is used as illustrated above. It has shape `(batch_size, seq_length)` and marks which input tokens are padding (`True` means padded and ignored by attention).\n", "\n", "`mask` serves a different purpose: it restricts which positions each position may attend to, and is typically used as a causal mask that prevents attending to future tokens in sequence generation tasks. Its shape is `(seq_length, seq_length)`." ] }, { "cell_type": "markdown", "metadata": { "id": "Cn6J6NGQxmRO" }, "source": [ "If we are dealing with a sequence tagging task, we also need to pad the labels." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XaOeIfy1xmRO" }, "outputs": [], "source": [ "# index 0 is reserved for the padding tag\n", "idx2tag = [\"<pad>\", \"DT\", \"NN\", \"ADJ\", \"VB\", \"IN\", \".\"]\n", "tag2idx = dict((t, i) for (i, t) in enumerate(idx2tag))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5qboJyCJxmRO", "outputId": "589f9645-657c-43c4-e012-09a0fb613ae0" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor([[1, 3, 3, 2, 4, 5, 1, 3, 2, 6],\n", " [1, 2, 4, 4, 5, 1, 2, 0, 0, 0]])\n" ] } ], "source": [ "max_length = max([len(labels) for labels in batch_labels])\n", "batch_targets = torch.tensor(\n", " [\n", " [tag2idx[tag] for tag in labels] + [tag2idx[\"<pad>\"]] * (max_length - len(labels)) for labels in batch_labels\n", " ]\n", ")\n", "print(batch_targets)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0WR55HptxmRP" }, "outputs": [], "source": [ "num_classes = len(idx2tag)\n", "# the Transformer encoder outputs embed_size features per token\n", "linear = nn.Linear(embed_size, num_classes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TYCamcY2xmRP", "outputId": "e3025b5f-f4c8-4fa0-cb0d-a0ec3a7d7cbf" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([2, 10, 7])\n" ] } ], "source": [ "logits = linear(output_masked)\n", "print(logits.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "BFTOgLEjxmRP", "outputId": "58bf3ff4-41a4-41d8-c79d-b104ec11a4da" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor(1.9610, grad_fn=)\n" ] } ], "source": [ "# flatten to (batch_size * seq_length, num_classes) logits and (batch_size * seq_length,) targets\n", "loss = F.cross_entropy(logits.view(-1, num_classes), batch_targets.view(-1))\n", "print(loss)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qFY3qLnQxmRP", "outputId": 
"d8423be5-c6c6-4505-d8c2-f1aab572d4a1" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor(1.8112, grad_fn=)\n" ] } ], "source": [ "# losses on padded labels are ignored:\n", "# we do not care about predictions on padded tokens because they are discarded anyway.\n", "loss_ignore_padding = F.cross_entropy(logits.view(-1, num_classes), batch_targets.view(-1), ignore_index=tag2idx[\"<pad>\"])\n", "print(loss_ignore_padding)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "eFMG3aAtxmRP", "outputId": "bf941f68-35d6-44d1-fb64-c1ab90392fe9" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor(1.8112, grad_fn=)\n" ] } ], "source": [ "# we can achieve the same result by keeping only the non-padded positions\n", "active_mask = torch.logical_not(padding_mask)\n", "loss_ignore_padding_v2 = F.cross_entropy(logits[active_mask], batch_targets[active_mask])\n", "print(loss_ignore_padding_v2)" ] }, { "cell_type": "markdown", "source": [ "### Other resources\n", "\n", "Generating Names with a Character-Level RNN: https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html\n", "\n", "Translation with a Sequence to Sequence Network and Attention: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html\n", "\n", "LSTM (character + word) POS-tag model PyTorch: https://www.kaggle.com/code/krishanudb/lstm-character-word-pos-tag-model-pytorch\n", "\n", "PyTorch POS Tagging: https://github.com/bentrevett/pytorch-pos-tagging" ], "metadata": { "id": "GCZX28jo3qpn" } } ], "metadata": { "interpreter": { "hash": "63a9c5d85f6e23ec927dc7ea406bd3e633b330764de0c8fed29765d0f6a2a64b" }, "kernelspec": { "display_name": "Python 3.9.7 ('torch')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" }, "orig_nbformat": 4, "colab": { "provenance": [], "collapsed_sections": [ "15rh0T_7xmRK", "T7xDNWGixmRL", "mTF6H7MjxmRM", "ZBxR4hlDxmRM" ] } }, "nbformat": 4, "nbformat_minor": 0 }