2022-02-24 01:24:20 +09:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "4c31f5ad",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'C:\\\\Users\\\\Monoid\\\\anaconda3\\\\envs\\\\nn\\\\python.exe'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sys\n",
"sys.executable"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2b9e11e7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"load bert tokenizer...\n"
]
}
],
"source": [
"from transformers import BertTokenizer\n",
"print(\"load bert tokenizer...\")\n",
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
"tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "82bf44a2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda available : True\n",
"available device count : 1\n",
"device name: NVIDIA GeForce RTX 3070\n"
]
}
],
"source": [
"import torch\n",
"print(\"cuda available :\",torch.cuda.is_available())\n",
"print(\"available device count :\",torch.cuda.device_count())\n",
"if torch.cuda.is_available():\n",
" device_index = torch.cuda.current_device()\n",
" print(\"device name:\",torch.cuda.get_device_name(device_index))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "38dcf62d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"read train set\n",
2022-02-27 19:50:13 +09:00
"100%|██████████████████████████████████████████████████████████████████████| 150000/150000 [00:00<00:00, 208913.67it/s]\n",
2022-02-24 01:24:20 +09:00
"read test set\n",
2022-02-27 19:50:13 +09:00
"100%|████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 259067.57it/s]\n"
2022-02-24 01:24:20 +09:00
]
}
],
"source": [
"from ndataset import readNsmcDataAll, make_collate_fn\n",
2022-02-27 19:50:13 +09:00
"dataTrain, dataDev, dataTest = readNsmcDataAll()\n",
2022-02-24 01:24:20 +09:00
"collate_fn = make_collate_fn(tokenizer)"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 25,
2022-02-24 01:24:20 +09:00
"id": "650c8a19",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2022-02-27 19:50:13 +09:00
"Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']\n",
2022-02-24 01:24:20 +09:00
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"from transformers import BertModel\n",
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
"bert = BertModel.from_pretrained(PRETAINED_MODEL_NAME)"
]
},
{
"cell_type": "markdown",
"id": "30d69b45",
"metadata": {},
"source": [
"BERT 로딩"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "7583b0d1",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self,bert):\n",
" super().__init__()\n",
" self.bert = bert\n",
" self.dropout = nn.Dropout(p=0.1)\n",
2022-02-27 19:50:13 +09:00
" self.lin1 = nn.Linear(768,768) #[batch_size,768] -> [batch_size,768]\n",
" self.gelu = nn.GELU()\n",
" self.lin2 = nn.Linear(768,1) #[batch_size,768] -> [batch_size,1]\n",
2022-02-24 01:24:20 +09:00
"\n",
" def forward(self,**kargs):\n",
" emb = self.bert(**kargs)\n",
" e1 = self.dropout(emb['pooler_output'])\n",
2022-02-27 19:50:13 +09:00
" e2 = self.gelu(self.lin1(e1))\n",
2022-02-24 01:24:20 +09:00
" w = self.lin2(e2)\n",
" return w.squeeze() #[batch_size]"
]
},
{
"cell_type": "markdown",
"id": "befe62b0",
"metadata": {},
"source": [
"모델 선언. 비슷하게 감."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "36585e76",
"metadata": {},
"outputs": [],
"source": [
"model = MyModel(bert)"
]
},
2022-02-27 19:50:13 +09:00
{
"cell_type": "code",
"execution_count": 8,
"id": "0e36d9e2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"\\nclass MyModel(nn.Module):\\n def __init__(self,embeddings):\\n super().__init__()\\n self.embeddings = embeddings\\n self.lin0 = nn.Linear(768,1)\\n \\n def forward(self,**kargs):\\n emb = self.embeddings(kargs['input_ids']) #[batch_size,word,768]\\n e0 = emb * kargs['attention_mask'].unsqueeze(dim=-1)#[batch_size,word,768]\\n e1 = self.lin0(e0)#[batch_size,word,1]\\n w = (e1.sum(axis=1))#[batch_size,1]\\n return w.squeeze() #[batch_size]\\nmodel = MyModel(bert.embeddings)\\nfor param in model.embeddings.parameters():\\n param.requires_grad = False\\n\""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"\n",
"class MyModel(nn.Module):\n",
" def __init__(self,embeddings):\n",
" super().__init__()\n",
" self.embeddings = embeddings\n",
" self.lin0 = nn.Linear(768,1)\n",
" \n",
" def forward(self,**kargs):\n",
" emb = self.embeddings(kargs['input_ids']) #[batch_size,word,768]\n",
" e0 = emb * kargs['attention_mask'].unsqueeze(dim=-1)#[batch_size,word,768]\n",
" e1 = self.lin0(e0)#[batch_size,word,1]\n",
" w = (e1.sum(axis=1))#[batch_size,1]\n",
" return w.squeeze() #[batch_size]\n",
"model = MyModel(bert.embeddings)\n",
"for param in model.embeddings.parameters():\n",
" param.requires_grad = False\n",
"\"\"\""
]
},
2022-02-24 01:24:20 +09:00
{
"cell_type": "markdown",
"id": "7969fead",
"metadata": {},
"source": [
2022-02-27 19:50:13 +09:00
"### 학습 과정에서 벌어지는 일"
2022-02-24 01:24:20 +09:00
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 9,
2022-02-24 01:24:20 +09:00
"id": "8c2a4bc9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MyModel(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(119547, 768, padding_idx=0)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
2022-02-27 19:50:13 +09:00
" (lin1): Linear(in_features=768, out_features=768, bias=True)\n",
" (gelu): GELU()\n",
" (lin2): Linear(in_features=768, out_features=1, bias=True)\n",
2022-02-24 01:24:20 +09:00
")"
]
},
2022-02-27 19:50:13 +09:00
"execution_count": 9,
2022-02-24 01:24:20 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.cpu()"
]
},
{
"cell_type": "code",
"execution_count": 10,
2022-02-27 19:50:13 +09:00
"id": "def10f6c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': tensor([[ 101, 9405, 62200, 14523, 48549, 119, 102, 0],\n",
" [ 101, 9294, 12424, 69592, 48549, 119, 102, 0],\n",
" [ 101, 9479, 68984, 48549, 119, 102, 0, 0],\n",
" [ 101, 9659, 22458, 119192, 12965, 48549, 119, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0],\n",
" [1, 1, 1, 1, 1, 1, 1, 0],\n",
" [1, 1, 1, 1, 1, 1, 0, 0],\n",
" [1, 1, 1, 1, 1, 1, 1, 1]])}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs = tokenizer([\"사랑해요.\",\"무서워요.\",\"슬퍼요.\",\"재미있어요.\"],\n",
" return_tensors = 'pt', padding='longest')\n",
"inputs"
]
},
{
"cell_type": "code",
"execution_count": 11,
2022-02-24 01:24:20 +09:00
"id": "e027b926",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 768])"
]
},
2022-02-27 19:50:13 +09:00
"execution_count": 11,
2022-02-24 01:24:20 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
2022-02-27 19:50:13 +09:00
"hidden = bert(**inputs)['pooler_output']\n",
2022-02-24 01:24:20 +09:00
"hidden.size()"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 12,
2022-02-24 01:24:20 +09:00
"id": "ae9f8fba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2022-02-27 19:50:13 +09:00
"tensor([-0.0180, -0.0823, -0.0234, -0.0886], grad_fn=<SqueezeBackward0>)"
2022-02-24 01:24:20 +09:00
]
},
2022-02-27 19:50:13 +09:00
"execution_count": 12,
2022-02-24 01:24:20 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Mirror MyModel.forward (lin1 -> GELU -> lin2); dropout is left out of this walkthrough.\n",
"w = model.lin2(model.gelu(model.lin1(hidden))).squeeze(-1)\n",
"w"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 13,
2022-02-24 01:24:20 +09:00
"id": "5470c3f8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2022-02-27 19:50:13 +09:00
"tensor([0.4955, 0.4794, 0.4942, 0.4779], grad_fn=<SigmoidBackward0>)"
2022-02-24 01:24:20 +09:00
]
},
2022-02-27 19:50:13 +09:00
"execution_count": 13,
2022-02-24 01:24:20 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.sigmoid(w)"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 14,
2022-02-24 01:24:20 +09:00
"id": "b7eb8e67",
"metadata": {},
"outputs": [],
"source": [
"labels = torch.tensor([1,0,0,1])"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 15,
2022-02-24 01:24:20 +09:00
"id": "7a324ed7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2022-02-27 19:50:13 +09:00
"tensor(0.6937, dtype=torch.float64,\n",
2022-02-24 01:24:20 +09:00
" grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)"
]
},
2022-02-27 19:50:13 +09:00
"execution_count": 15,
2022-02-24 01:24:20 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Cast the targets to the logits' dtype so the loss is not promoted to float64.\n",
"nn.BCEWithLogitsLoss()(w,labels.to(w.dtype))"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 16,
2022-02-24 01:24:20 +09:00
"id": "cb54294d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2022-02-27 19:50:13 +09:00
"tensor([False, True, True, False])"
2022-02-24 01:24:20 +09:00
]
},
2022-02-27 19:50:13 +09:00
"execution_count": 16,
2022-02-24 01:24:20 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(w > 0).long() == labels"
]
},
{
"cell_type": "markdown",
"id": "596b89bd",
"metadata": {},
"source": [
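"위 셀들이 학습 한 스텝에서 실제로 일어나는 계산이다. 모델은 logit만 출력하고 sigmoid는 `BCEWithLogitsLoss` 내부에서 적용되므로, 확률값이 필요할 때에만 나중에 `torch.sigmoid`를 따로 적용한다."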
]
},
2022-02-27 19:50:13 +09:00
{
"cell_type": "markdown",
"id": "9875ce48",
"metadata": {},
"source": [
"### Model Training"
]
},
2022-02-24 01:24:20 +09:00
{
"cell_type": "code",
"execution_count": 28,
"id": "769c4290",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MyModel(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(119547, 768, padding_idx=0)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
2022-02-27 19:50:13 +09:00
" (lin1): Linear(in_features=768, out_features=768, bias=True)\n",
" (gelu): GELU()\n",
" (lin2): Linear(in_features=768, out_features=1, bias=True)\n",
2022-02-24 01:24:20 +09:00
")\n"
]
}
],
"source": [
"# Use the GPU when one is present; otherwise fall back to the CPU so the\n",
"# notebook still runs end-to-end on a machine without CUDA.\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"model.to(device)\n",
"print(model)"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 29,
2022-02-24 01:24:20 +09:00
"id": "b9380dcd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda', index=0)"
]
},
2022-02-27 19:50:13 +09:00
"execution_count": 29,
2022-02-24 01:24:20 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check: display the device reported by the `bert` object.\n",
"bert.device"
]
},
{
"cell_type": "markdown",
"id": "5e82df0e",
"metadata": {},
"source": [
"모델을 모두 gpu로 보냄"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 30,
2022-02-24 01:24:20 +09:00
"id": "74c4becc",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import Dataset, DataLoader\n",
"# Mini-batch size passed to every DataLoader built in this cell.\n",
"BATCH_SIZE = 16\n",
"# Training loader: shuffle=True reshuffles the data at every epoch; each\n",
"# batch is assembled by `collate_fn`.\n",
"train_loader = DataLoader(\n",
" dataTrain,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" collate_fn=collate_fn\n",
")\n",
2022-02-27 19:50:13 +09:00
"# Validation loader: keep a fixed order (no shuffling) so evaluation runs\n",
"# are reproducible batch-for-batch.\n",
"dev_loader = DataLoader(\n",
"    dataDev,\n",
"    batch_size=BATCH_SIZE,\n",
"    shuffle=False,\n",
"    collate_fn=collate_fn\n",
")\n",
2022-02-24 01:24:20 +09:00
"# Test loader: keep a fixed order (no shuffling) so evaluation runs are\n",
"# reproducible batch-for-batch.\n",
"test_loader = DataLoader(\n",
"    dataTest,\n",
"    batch_size=BATCH_SIZE,\n",
"    shuffle=False,\n",
"    collate_fn=collate_fn\n",
")"
]
},
{
"cell_type": "markdown",
"id": "4153b2e7",
"metadata": {},
"source": [
"데이터 모델 준비"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 31,
2022-02-24 01:24:20 +09:00
"id": "3cd5bf7b",
"metadata": {},
"outputs": [],
"source": [
"from torch.optim import AdamW\n",
"from groupby_index import groupby_index\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 32,
2022-02-24 01:24:20 +09:00
"id": "65b5ccde",
"metadata": {},
"outputs": [],
"source": [
"# AdamW over everything returned by model.parameters(), learning rate 1e-5.\n",
"optimizer = AdamW(model.parameters(), lr=1.0e-5)\n",
"# BCEWithLogitsLoss applies the sigmoid itself, so it is fed raw logits.\n",
"BCELoss = nn.BCEWithLogitsLoss()"
]
},
{
"cell_type": "markdown",
"id": "79607e81",
"metadata": {},
"source": [
"학습 준비"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 33,
"id": "68183cf4",
"metadata": {},
"outputs": [],
"source": [
"def freezeParameter(model):\n",
"    \"\"\"Freeze the BERT backbone of `model` except for its embedding layer.\n",
"\n",
"    Every parameter under `model.bert` stops receiving gradients, then the\n",
"    parameters under `model.bert.embeddings` are re-enabled. Parameters\n",
"    outside `model.bert` are left untouched. Mutates `model` in place and\n",
"    returns None.\n",
"    \"\"\"\n",
"    for param in model.bert.parameters():\n",
"        param.requires_grad = False\n",
"    # Re-enable the embeddings after the blanket freeze above.\n",
"    for param in model.bert.embeddings.parameters():\n",
"        param.requires_grad = True"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "42d1a29f",
"metadata": {},
"outputs": [],
"source": [
"# Optional: uncomment to freeze model.bert except its embeddings before\n",
"# training (see freezeParameter); left disabled here.\n",
"#freezeParameter(model)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "cc8a82a5",
"metadata": {},
"outputs": [],
"source": [
"# Number of passes over the training loader.\n",
"TRAIN_EPOCH = 5\n",
"\n",
"# Training log (one dict with `iter`, `loss`, `accuracy` per optimizer step)\n",
"# and the running step counter; both are updated by the training loop.\n",
"result = []\n",
"iteration = 0"
]
},
{
"cell_type": "code",
"execution_count": 36,
2022-02-24 01:24:20 +09:00
"id": "4835a0d3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2022-02-27 19:50:13 +09:00
"Epoch 0: 100%|███████████████████████████████████| 9375/9375 [11:41<00:00, 13.37minibatch/s, accuracy=0.844, loss=5.79]\n"
2022-02-24 01:24:20 +09:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2022-02-27 19:50:13 +09:00
"Epoch 1: 100%|████████████████████████████████████| 9375/9375 [11:48<00:00, 13.23minibatch/s, accuracy=0.84, loss=5.37]\n"
2022-02-24 01:24:20 +09:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 2 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2022-02-27 19:50:13 +09:00
"Epoch 2: 100%|███████████████████████████████████| 9375/9375 [11:44<00:00, 13.31minibatch/s, accuracy=0.879, loss=4.79]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 3 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 3: 100%|███████████████████████████████████| 9375/9375 [11:42<00:00, 13.35minibatch/s, accuracy=0.852, loss=5.12]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 4 start:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 4: 8%|██▊ | 717/9375 [00:53<10:48, 13.36minibatch/s, accuracy=0.918, loss=3.09]\n"
2022-02-24 01:24:20 +09:00
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
2022-02-27 19:50:13 +09:00
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_29640/636965671.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 17\u001b[0m \u001b[0mbatch_labels\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmini_l\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 18\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 19\u001b[1;33m \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mbatch_inputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 20\u001b[0m \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mBCELoss\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_labels\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdouble\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 21\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\envs\\nn\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1103\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_29640/2430988958.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, **kargs)\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mkargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 13\u001b[1;33m \u001b[0memb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbert\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mkargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 14\u001b[0m \u001b[0me1\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0memb\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'pooler_output'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 15\u001b[0m \u001b[0me2\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlin1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\envs\\nn\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1103\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\envs\\nn\\lib\\site-packages\\transformers\\models\\bert\\modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[0;32m 1004\u001b[0m \u001b[0moutput_attentions\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1005\u001b[0m \u001b[0moutput_hidden_states\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0moutput_hidden_states\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1006\u001b[1;33m \u001b[0mreturn_dict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1007\u001b[0m )\n\u001b[0;32m 1008\u001b[0m \u001b[0msequence_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mencoder_outputs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\envs\\nn\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1103\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\envs\\nn\\lib\\site-packages\\transformers\\models\\bert\\modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[0;32m 590\u001b[0m \u001b[0mencoder_attention_mask\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 591\u001b[0m \u001b[0mpast_key_value\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 592\u001b[1;33m \u001b[0moutput_attentions\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 593\u001b[0m )\n\u001b[0;32m 594\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\envs\\nn\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1103\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\envs\\nn\\lib\\site-packages\\transformers\\models\\bert\\modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[0;32m 475\u001b[0m \u001b[0mhead_mask\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 476\u001b[0m \u001b[0moutput_attentions\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 477\u001b[1;33m \u001b[0mpast_key_value\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself_attn_past_key_value\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 478\u001b[0m )\n\u001b[0;32m 479\u001b[0m \u001b[0mattention_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself_attention_outputs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\envs\\nn\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1103\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\envs\\nn\\lib\\site-packages\\transformers\\models\\bert\\modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[0;32m 407\u001b[0m \u001b[0mencoder_attention_mask\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 408\u001b[0m \u001b[0mpast_key_value\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 409\u001b[1;33m \u001b[0moutput_attentions\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 410\u001b[0m )\n\u001b[0;32m 411\u001b[0m \u001b[0mattention_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moutput\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself_outputs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\envs\\nn\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1103\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\envs\\nn\\lib\\site-packages\\transformers\\models\\bert\\modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[0;32m 304\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 305\u001b[0m \u001b[1;31m# Take the dot product between \"query\" and \"key\" to get the raw attention scores.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 306\u001b[1;33m \u001b[0mattention_scores\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquery_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkey_layer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m-\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 307\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 308\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mposition_embedding_type\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m\"relative_key\"\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mposition_embedding_type\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m\"relative_key_query\"\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
2022-02-24 01:24:20 +09:00
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# Number of mini-batches whose gradients are accumulated before each\n",
"# optimizer step.\n",
"GRAD_ACCUM_STEPS = 16\n",
"\n",
"model.zero_grad()\n",
"\n",
"for epoch in range(TRAIN_EPOCH):\n",
"    model.train()  # training mode (dropout active)\n",
"    print(f\"epoch {epoch} start:\")\n",
"    with tqdm(train_loader, unit=\"minibatch\") as tepoch:\n",
"        tepoch.set_description(f\"Epoch {epoch}\")\n",
"\n",
"        # groupby_index presumably yields the loader's mini-batches in groups\n",
"        # of up to GRAD_ACCUM_STEPS - TODO confirm against groupby_index.\n",
2022-02-27 19:50:13 +09:00
"        for batch in groupby_index(tepoch, GRAD_ACCUM_STEPS):\n",
2022-02-24 01:24:20 +09:00
"            # Statistics are reset per group, so every logged point covers\n",
"            # exactly one optimizer step.\n",
"            corrects = 0\n",
"            totals = 0\n",
"            losses = 0\n",
"\n",
"            optimizer.zero_grad()\n",
"            for mini_i, mini_l in batch:\n",
"                batch_inputs = {k: v.to(device) for k, v in mini_i.items()}\n",
"                batch_labels = mini_l.to(device)\n",
"\n",
"                output = model(**batch_inputs)\n",
"                loss = BCELoss(output, batch_labels.double())\n",
"\n",
"                # A logit above 0 corresponds to a sigmoid probability above 0.5.\n",
"                prediction = (output > 0).to(device, dtype=torch.int64)\n",
"                corrects += (prediction == batch_labels).sum().item()\n",
"                totals += prediction.size()[0]\n",
"                # NOTE: `losses` is the SUM of the per-mini-batch losses in\n",
"                # this group, not their mean.\n",
"                losses += loss.item()\n",
"                # Gradients add up across the group until optimizer.step().\n",
"                loss.backward()\n",
"\n",
"            optimizer.step()\n",
"            accuracy = corrects / totals\n",
"            result.append({\"iter\":iteration,\"loss\":losses,\"accuracy\":accuracy})\n",
"            tepoch.set_postfix(loss=losses, accuracy=accuracy)\n",
"            iteration += 1"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 18,
2022-02-24 01:24:20 +09:00
"id": "81b69931",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 37,
2022-02-24 01:24:20 +09:00
"id": "c3a73c68",
"metadata": {},
"outputs": [
{
"data": {
2022-02-27 19:50:13 +09:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAD4CAYAAAAHHSreAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAABPbUlEQVR4nO2dd5gUxdPHv3UBOA7JOQfJICgZRJIkEURFRFSCoIKKGQVUggnDKxhAkR8iCKiABBGRHEWiBEmSM0eQDEe6u37/6J3d2cmzOxtmrz/Pc8/O9vT09NzudnVXVVcRYwwCgUAgEMiJi3QHBAKBQBB9COEgEAgEAhVCOAgEAoFAhRAOAoFAIFAhhINAIBAIVCSE82ZxcXEsKSkpnLcUCAQC15OamsoYY2GdzIdVOCQlJeHq1avhvKVAIBC4HiK6Fu57CrWSQCAQCFQI4SAQCAQCFUI4CAQCgUCFEA4CgUAgUCGEg0AgEAhUCOEgEAgEAhVCOAgEAoFAhSuEw+mrpzF9x/RId0OQWTh+HJg7N9K9EAgiiiuEw8PTHkbnXzqj64yuuHTjUqS7I4h1GjQA2rePdC8EgohC4Uz2k5yczALZIU3DyO89GyISFAlCCHm+byIRliBKIKJUxlhyOO/pipXDO/e8E+kuCDIjQjgIMjGuEA7vNnvX771IbSoIC+J7JsjEuEI4AMDsR2d7j99d8a5+RYHAKYRwEGRiXGFzkJDbHoTdQRAyJJvDzZtAYmJk+yIQQNgcTMmWkC3SXRBkJsTKQZCJcZVwmNZpWqS7IMhMZGREugeCWOD6dWC6+/ZpuUo4tK/o8z0/eP5gBHsiyBSIlYPACd54A+jcGVi5MtI9sYWrhIOcwxcPR7oLglhHrByco2dP4P/+z/51aWlA3brAokX6ddatA+64A0hNNW9vyhSgbVv7/QiGo0f56/nz4b1vkLhWODSb2CzSXRDEOpl95XD1KnDypLr8xAmuKrHDhAlA//7AtWtASopx3YwM4NAhfpySAmzYADz1lH79V18Ftm0DNm/WPn/jhm+AfuIJYP584NgxXm7EgQPG583qp6fz55AcHFw22XCtcBAIQo7LfsyO06ABUKSIurxYMeChhwJrs3VroGhR4zoffACUKQPs3esrMxLUZkK8Rw+gZEnufSZRogTQrZv+NYsWAeXKAT/+aNy2xB9/8PrTZHbRwYP5cxw5wt+77PskhINAoEcgK4dRo4Dt253vixzGgPff57PfULJtm+941ixg4ULf+z/+UNefNw+YPdu4zVWrzO87eDB/PXAAeOstfnz8uP/grsVbb2mrlqQ+paVpl0vMnMlXFQDwzz/89e+/zfurrH/kCFeFffghLztzhr8K4RBaNj2zKdJdEMQy1675jgP5MffrB9So4Vx/tNixA3jnHW7kDBcPPcRn/Ua0awc8+KBz95w4EZg0yff+hx+060lCfMUK4OOP1eelz9FM2D/8sM8eIV0TF8AQ2b49V4VJCLVSeCiXt1ykuyCIZSpV8h0HanNwehBo04arQSSkGbA0Sx40yDcAKTl1ip/7+Wft8x9/zM/fumWtL/L/iZVVgB7//stfP/lE//4//eT/Xlnn22/5tWvX+spSU/mKg8i3CkhP5685cvhfr/f5EnEPI+n46FH+OmeOuu6AAer//blz/u8PHza+X5TiOuGQGOfbsSpiLAkcR9IPA8EN8leu8NcNG3zG0EBZsEBbhSR9/4cP1792507++u232uffe4+/WjUwy39zwfjuf/65//0PHeKD/IkT+tcoZ/Gffaauc/KkT1hIKw29z/HWLe69tGOH/j1PnQJGj+bH48apzytXKidO6Kv75BMPF+A64ZAlPov3eP3x9RHsiSDmCWby0aMHf61bFyhd2one+NBbJWhh9gzSeattOrUqkoSV1F6DBvzPaABV9lGrL5Mn+4SIFXXSE08A1arpn//hB58AkFYgRkyerH+ufHnz66MI1wmH+Lh473
H97+pHsCcCx/nyS99yPhqQDz7Vq/PBycgoKh+EZswALl70b2fECJ+BFeBqDyI+i27cGLh8Wd3mhQu+48aNgV279O//wgvqsvff56/LlwObNOx1kjunVcOrPNaU3qArDaJDhwKffmrcntTG2bP8Vet/IHHtGrB4Mf8/EAH792vX+/pr/jp1KvDSS8b3l7j7bvM6mzcDTZrw1Y5kNJewMpGIjzevE00wxsL2lz17duYE3Wd1ZxgKhqFwpD2BA6SnM5aSElwb/Cdm/7rjxxnLyAju3so+AIwdPcrYqVP+5UuXMpaaqr7u6lXGzp71v37iRP9nUj6fvC7A2PTp6nbff9+/TocOjG3ezI/vuEPdjpyLF/3P1avH2H//MXb9uroPFSvq/y+0+gow9sILjJ04wf9P8jrXrzN24YLv/blz2tczxljWrNrntP5mzGAsKcl6/VD/yZ954EDz+rdu6X7tzABwlYVpnJb+XLdyAID7K9wf6S4IlLz/PveJl+vsw8HWrdzvfswY59t++mmgUCH/1ULz5kDt2uq6NWsC+fL5lzGT2aRSTaKcWZ4+Dbz9tv41ZqqgXLnU7efPzw3cTpCayvcslCjBPYskMjKA3Ll97/Pm1b5+2TJrqhqJW7eiK0qu5L4KGNt9JEK4ciCiNkS0m4j2EdEAjfN5iGgWEf1DROuJyECXxnGlcMiRJYd5pcxOejrfTGS0THeSefP46/HjobsHY8BHH/l7g+zezV+XLQuszf37gW++0T4nebssX+5fLhl55cg3bEmYCQflebnBddEibddNIm3DqBWk9pXPY4ae4JX+94Dv8wes7/No3ly998CIdeuAS1GUQ/4dmxkq7diKbDVL8QBGA2gLoAqAx4ioiqLaIABbGGN3AOgG4Auzdl0pHO4uaUE/mNmZMYPPOt98M7z3DdEPAACwdCkwcCDQp4/6fmYDsR6NGwPPPWfsrWPm36+HvE/K/mkNinLh0KoVDzehVUfynrFLoDPXvn21y1ev9h3LP/e6dQO7jxkjR4am3UDRcm2NDHUB7GOMHWCM3QTwM4AHFHWqAFgCAIyxfwGUJqJCRo26UjjIVw4s0EHBrfToAZQqZV5PMjTqzbS6duU/6Oef56/Sbs5ACcfnIKl35M+k9Eyxi7QKsdv/nj3N6xgJBy11SocOfO+AkYA1E7579/I6WvXkwoHIf5dzsJ/f1KnBXS8wI4GINsr+npGdKwZA7i99zFMmZyuAhwCAiOoCKAWguNENXSkc5Hy+9vNIdyG8TJxoTa9vNqOWNhhJnh1yL5pAkO6zYQMfcA8dMvasAYA1a3wePXIOHvQd79jh2yegNeAZCYcjR3wqoOXLtVcHUpsLFtgTMBMm8Gf88kt+rRby/71c3bZ4sb7XU4cOxvc1C1onD3GhRLlPQK4u2rMHePJJvneBMf/QGYJoII0xVlv2N1Z2TmvGoPzhfwQgDxFtAdAPwGYAhjo9U+FAROOJ6DQRbZeV5SWiRUS01/Oax6ydUDFl25RI3Tq6kQaCcK2spPu8+CLXJZcpA1RRqj1l3LgBNGyoPRiWLes7rlaNB03Tw0gIlioFVK3KhVSzZrxvetc/+KBvY5ZVypThrpJ6Bl55n+T/i5YtuZunFnLXVS3++st3bFeFZ6ZWmjyZh+SYOZOHwBa4hWMAZFvoURyA325CxtglxlhPxlhNcJtDAQAHYYCVlcMEAMpv/wAASxhj5cH1WCrreKipX5zvcaBQ6rjdSloa8Pjj/Pinn7Rn53bYuJGroazOrLduNa8jhUKQ/OulnbJGLFnCXxcs4MJl7lxfLJ9ff/XVe+89f+8ZKY6+1mxY/v05aPhbsY98li/tmJaw8j9yGuXKQS/8xb59oe+LwEk2AChPRGWIKAuALgD8DCJElNtzDgB6A1jJGDO07psKB8bYSgCKYCF4AID065sIoKNp9x1GCqOx8cRGrDi0Ity3jzznz/s2Dsm5dk098Fg1YEpqlwsXuGpIMpp26MCFTEoKH5
SVAx1gHBv/2jX/gHaAT9Bcvcpn2MpNRVrIk8WsXcsDnGkxeLBvhzLgv4q6dcvnwZWW5h/F87//zPtgByOBZxSyIRi
2022-02-24 01:24:20 +09:00
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plot the logged training curves: loss on the left axis (green),\n",
"# accuracy on the right axis (red).\n",
"iters = [item[\"iter\"] for item in result]\n",
"fig, ax1 = plt.subplots()\n",
"ax1.plot(iters, [item[\"loss\"] for item in result], 'g')\n",
"# Label through ax1: twinx() hides the twin's x-axis, so calling\n",
"# plt.xlabel() after twinx() attaches the label to an invisible axis.\n",
"ax1.set_xlabel(\"iter\")\n",
"ax1.set_ylabel(\"loss\", color='g')\n",
"ax2 = ax1.twinx()\n",
"ax2.plot(iters, [item[\"accuracy\"] for item in result], 'r')\n",
"ax2.set_ylabel(\"accuracy\", color='r')\n",
"plt.show()"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 45,
2022-02-24 01:24:20 +09:00
"id": "cab7889f",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-02-27 19:50:13 +09:00
"gpu allocated : 2778 MB\n",
"gpu reserved : 3272 MB\n"
2022-02-24 01:24:20 +09:00
]
}
],
"source": [
"# Release the unused memory cached by PyTorch's CUDA allocator, then report\n",
"# what is still occupied by tensors vs. held by the caching allocator.\n",
"torch.cuda.empty_cache()\n",
"print(f\"gpu allocated : {torch.cuda.memory_allocated() // 1024**2} MB\")\n",
"print(f\"gpu reserved : {torch.cuda.memory_reserved() // 1024 ** 2} MB\")"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 46,
2022-02-24 01:24:20 +09:00
"id": "29ffab84",
"metadata": {},
"outputs": [],
"source": [
"# Persist only the weights (state_dict), not the module object itself.\n",
"torch.save(model.state_dict(), \"model.zip\")"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 21,
2022-02-24 01:24:20 +09:00
"id": "4b9b9579",
"metadata": {},
"outputs": [],
"source": [
"# Drop the names left over from the training loop (last batch, last loss,\n",
"# optimizer), presumably so the memory they reference can be reclaimed.\n",
"del batch_inputs\n",
"del batch_labels\n",
"del loss\n",
"del optimizer"
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 88,
"id": "80d0ee50",
2022-02-24 01:24:20 +09:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2022-02-27 19:50:13 +09:00
"<All keys matched successfully>"
2022-02-24 01:24:20 +09:00
]
},
2022-02-27 19:50:13 +09:00
"execution_count": 88,
2022-02-24 01:24:20 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
2022-02-27 19:50:13 +09:00
"# Deserialize onto the CPU first: this works on machines without CUDA and\n",
"# avoids a second full copy of the weights on the GPU. load_state_dict then\n",
"# copies the values into the model's existing parameters.\n",
"model.load_state_dict(torch.load(\"model.zip\", map_location=\"cpu\"))"
2022-02-24 01:24:20 +09:00
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 89,
"id": "5c9b570c",
2022-02-24 01:24:20 +09:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MyModel(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(119547, 768, padding_idx=0)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
2022-02-27 19:50:13 +09:00
" (lin1): Linear(in_features=768, out_features=768, bias=True)\n",
" (gelu): GELU()\n",
" (lin2): Linear(in_features=768, out_features=1, bias=True)\n",
2022-02-24 01:24:20 +09:00
")"
]
},
2022-02-27 19:50:13 +09:00
"execution_count": 89,
2022-02-24 01:24:20 +09:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
2022-02-27 19:50:13 +09:00
"# Prefer the GPU but fall back to the CPU when CUDA is unavailable.\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"# Trailing semicolon: keep the full module repr out of the cell output.\n",
"model.to(device);"
2022-02-24 01:24:20 +09:00
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 38,
"id": "fff7a7d0",
2022-02-24 01:24:20 +09:00
"metadata": {},
"outputs": [
{
2022-02-27 19:50:13 +09:00
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████| 1250/1250 [00:34<00:00, 36.39batch/s]\n"
2022-02-24 01:24:20 +09:00
]
}
],
"source": [
2022-02-27 19:50:13 +09:00
"model.eval()\n",
"collect_list = []  # one record of metrics, predictions and labels per dev batch\n",
"with torch.no_grad():\n",
"    with tqdm(dev_loader, unit=\"batch\") as tepoch:\n",
"        for batch_i,batch_l in tepoch:\n",
"            # .to(device) follows whatever `device` is, unlike .cuda() which is GPU-only\n",
"            batch_inputs = {k: v.to(device) for k, v in batch_i.items()}\n",
"            batch_labels = batch_l.to(device)\n",
"            output = model(**batch_inputs)\n",
"            loss = BCELoss(output, batch_labels.double())\n",
"\n",
"            # the model output is thresholded at 0 to obtain the 0/1 class\n",
"            prediction = (output > 0).to(dtype=torch.int64)\n",
"            correct = (prediction == batch_labels).sum().item()\n",
"            accuracy = correct / prediction.size(0)\n",
"\n",
"            # predictions and labels are copied to the CPU before being stored\n",
"            collect_list.append({\"loss\":loss.item(),\"accuracy\":accuracy, \"batch_size\":batch_labels.size(0),\n",
"                                 \"predict\":prediction.cpu(),\n",
"                                 \"actual\":batch_labels.cpu()})"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "4e9a90b5",
"metadata": {},
"outputs": [],
"source": [
"def getConfusionMatrix(predict,actual,num_classes=2):\n",
"    \"\"\"Count how often each (predicted, actual) label pair occurs.\n",
"\n",
"    Args:\n",
"        predict: tensor of predicted class ids.\n",
"        actual: tensor of true class ids with as many elements as `predict`.\n",
"        num_classes: number of distinct classes; defaults to 2 (binary).\n",
"\n",
"    Returns:\n",
"        A (num_classes, num_classes) long tensor on the CPU whose entry [p, a]\n",
"        is the number of samples predicted as class p whose true class is a.\n",
"    \"\"\"\n",
"    # Encode each pair as the flat index p * num_classes + a so that a single\n",
"    # bincount replaces the per-sample Python loop.\n",
"    pair_index = predict.reshape(-1).long() * num_classes + actual.reshape(-1).long()\n",
"    counts = torch.bincount(pair_index, minlength=num_classes * num_classes)\n",
"    return counts.reshape(num_classes, num_classes).cpu()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "b7a513c9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"average_loss : 0.33939967472716237, average_accuracy : 0.86565, size :20000\n"
]
}
],
"source": [
"if not collect_list:\n",
"    # an empty list would otherwise surface as a ZeroDivisionError in the averages below\n",
"    raise ValueError(\"collect_list is empty; run the evaluation cell first\")\n",
"\n",
"# Per-batch means are weighted by batch size before averaging, so batches of different sizes are handled correctly.\n",
"total_size = sum(item[\"batch_size\"] for item in collect_list)\n",
"total_loss = sum(item[\"batch_size\"] * item[\"loss\"] for item in collect_list)\n",
"total_accuracy = sum(item[\"batch_size\"] * item[\"accuracy\"] for item in collect_list)\n",
"\n",
"# confusion[p, a] counts samples predicted as p whose actual label is a\n",
"confusion = torch.zeros((2,2),dtype=torch.long)\n",
"for item in collect_list:\n",
"    confusion += getConfusionMatrix(item[\"predict\"],item[\"actual\"])\n",
"print(f\"\"\"average_loss : {total_loss/total_size}, average_accuracy : {total_accuracy/total_size}, size :{total_size}\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "1ac327de",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[8716, 1638],\n",
" [1194, 8452]])"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Rows index the predicted label, columns the actual label (see getConfusionMatrix).\n",
"confusion"
]
},
{
"cell_type": "code",
"execution_count": 122,
"id": "3e71d4d2",
"metadata": {},
"outputs": [],
"source": [
"def getF1Score(confusion,c):\n",
"    \"\"\"Compute the F1 score of class `c` from a confusion matrix.\n",
"\n",
"    Args:\n",
"        confusion: square tensor of counts indexed as [predicted, actual].\n",
"        c: index of the class treated as positive.\n",
"\n",
"    Returns:\n",
"        A 0-dim float tensor holding the F1 score. It is 0.0 (not nan) when the\n",
"        class has no true positives, including when the class never occurs.\n",
"    \"\"\"\n",
"    TP = confusion[c,c]\n",
"    FP = confusion[c].sum() - TP     # predicted c, actually another class\n",
"    FN = confusion[:,c].sum() - TP   # actually c, predicted another class\n",
"\n",
"    # 2*P*R / (P + R) simplifies to 2*TP / (2*TP + FP + FN), which leaves a\n",
"    # single denominator to guard against 0/0.\n",
"    denominator = 2*TP + FP + FN\n",
"    if denominator == 0:\n",
"        return torch.zeros(())\n",
"    return 2*TP / denominator"
]
},
{
"cell_type": "code",
"execution_count": 242,
"id": "6756408c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f1 score : 0.7763830423355103\n"
]
}
],
"source": [
"# Class 1 is the label the evaluation cells assign when the model output is > 0.\n",
"print(f\"f1 score : {getF1Score(confusion,1)}\")"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "f28f64e9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MyModel(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(119547, 768, padding_idx=0)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (lin1): Linear(in_features=768, out_features=768, bias=True)\n",
" (gelu): GELU()\n",
" (lin2): Linear(in_features=768, out_features=1, bias=True)\n",
")"
]
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Switch to evaluation mode; the trailing semicolon keeps the full module repr out of the cell output.\n",
"model.eval();"
]
},
{
"cell_type": "markdown",
"id": "2da5789b",
"metadata": {},
"source": [
"한번 테스트해보기"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "cc727fd9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"웹 gpu\n",
"긍정적 output : 54.185086488723755 %\n",
"부정적 output : 45.814913511276245 %\n"
]
}
],
"source": [
"sen = input()\n",
"# truncation=True caps free-form input at the tokenizer's model max length, so an\n",
"# over-long sentence cannot run past BERT's 512 position embeddings.\n",
"inputs = tokenizer(sen, return_tensors = 'pt', padding='longest', truncation=True)\n",
2022-02-24 01:24:20 +09:00
"# Inference only: skip autograd bookkeeping, as the evaluation loops do.\n",
"with torch.no_grad():\n",
"    output = model(**{k: v.to(device) for k,v in inputs.items() })\n",
"# sigmoid turns the single model output into the probability printed as positive (긍정적)\n",
"prob = torch.sigmoid(output).item()\n",
"print(\"긍정적 output :\",prob * 100,\"%\")\n",
"print(\"부정적 output :\", (1-prob) * 100,\"%\")"
]
},
{
"cell_type": "markdown",
"id": "2faa8141",
"metadata": {},
"source": [
"```\n",
"5471412\t맘에 들어요~ 0\n",
"```\n",
"라벨이 잘못 붙어 있는 것들이 있다. 별점만 가지고 긍정, 부정을 매긴 것 같다."
]
},
{
"cell_type": "code",
2022-02-27 19:50:13 +09:00
"execution_count": 40,
2022-02-24 01:24:20 +09:00
"id": "b40f071c",
"metadata": {},
2022-02-27 19:50:13 +09:00
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████| 1875/1875 [00:51<00:00, 36.45batch/s]\n"
]
}
],
"source": [
"model.eval()\n",
"collect_list = []  # one record of metrics, predictions and labels per test batch\n",
"with torch.no_grad():\n",
"    with tqdm(test_loader, unit=\"batch\") as tepoch:\n",
"        for batch_i,batch_l in tepoch:\n",
"            # .to(device) follows whatever `device` is, unlike .cuda() which is GPU-only\n",
"            batch_inputs = {k: v.to(device) for k, v in batch_i.items()}\n",
"            batch_labels = batch_l.to(device)\n",
"            output = model(**batch_inputs)\n",
"            loss = BCELoss(output, batch_labels.double())\n",
"\n",
"            # the model output is thresholded at 0 to obtain the 0/1 class\n",
"            prediction = (output > 0).to(dtype=torch.int64)\n",
"            correct = (prediction == batch_labels).sum().item()\n",
"            accuracy = correct / prediction.size(0)\n",
"\n",
"            # predictions and labels are copied to the CPU before being stored\n",
"            collect_list.append({\"loss\":loss.item(),\"accuracy\":accuracy, \"batch_size\":batch_labels.size(0),\n",
"                                 \"predict\":prediction.cpu(),\n",
"                                 \"actual\":batch_labels.cpu()})"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "871b72d4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"average_loss : 0.3305340794156926, average_accuracy : 0.8672333333333333, size :30000\n"
]
}
],
"source": [
"if not collect_list:\n",
"    # an empty list would otherwise surface as a ZeroDivisionError in the averages below\n",
"    raise ValueError(\"collect_list is empty; run the evaluation cell first\")\n",
"\n",
"# Per-batch means are weighted by batch size before averaging, so batches of different sizes are handled correctly.\n",
"total_size = sum(item[\"batch_size\"] for item in collect_list)\n",
"total_loss = sum(item[\"batch_size\"] * item[\"loss\"] for item in collect_list)\n",
"total_accuracy = sum(item[\"batch_size\"] * item[\"accuracy\"] for item in collect_list)\n",
"\n",
"# confusion[p, a] counts samples predicted as p whose actual label is a\n",
"confusion = torch.zeros((2,2),dtype=torch.long)\n",
"for item in collect_list:\n",
"    confusion += getConfusionMatrix(item[\"predict\"],item[\"actual\"])\n",
"print(f\"\"\"average_loss : {total_loss/total_size}, average_accuracy : {total_accuracy/total_size}, size :{total_size}\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62333944",
"metadata": {},
2022-02-24 01:24:20 +09:00
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}