feat: batch and collate function
This commit is contained in:
parent
58fba0bd3c
commit
fbc8a25e30
153
Batch.ipynb
Normal file
153
Batch.ipynb
Normal file
@ -0,0 +1,153 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "c916dd3b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'C:\\\\Users\\\\Monoid\\\\anaconda3\\\\envs\\\\nn\\\\python.exe'"
|
||||
]
|
||||
},
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.executable"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "d5861234",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import List\n",
|
||||
"from torch.utils.data import Dataset\n",
|
||||
"import torch\n",
|
||||
"from transformers import PreTrainedTokenizer\n",
|
||||
"from ndata import readNsmcRawData, NsmcRawData"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "5accd3a9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"load bert tokenizer...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████████████████████████████████████████████████████████████████| 150000/150000 [00:00<00:00, 205761.43it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from transformers import BertTokenizer\n",
|
||||
"print(\"load bert tokenizer...\")\n",
|
||||
"PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'\n",
|
||||
"tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)\n",
|
||||
"\n",
|
||||
"data = readNsmcRawData(\"nsmc/nsmc-master/ratings_train.txt\",use_tqdm=True,total=150000)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d10fcb83",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"data를 준비"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "552fe555",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"({'input_ids': tensor([[ 101, 9519, 9074, 119005, 119, 119, 9708, 119235, 9715,\n",
|
||||
" 119230, 16439, 77884, 48549, 9284, 22333, 12692, 102, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0],\n",
|
||||
" [ 101, 100, 119, 119, 119, 9928, 58823, 30005, 11664,\n",
|
||||
" 9757, 118823, 30858, 18227, 119219, 119, 119, 119, 119,\n",
|
||||
" 9580, 41605, 25486, 12310, 20626, 23466, 8843, 118986, 12508,\n",
|
||||
" 9523, 17196, 16439, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0],\n",
|
||||
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0],\n",
|
||||
" [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
||||
" 1, 1, 1, 1, 1, 1, 1]])}, tensor([0, 1]))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def make_collate_fn(tokenzier: PreTrainedTokenizer):\n",
|
||||
" def collate_fn(batch: List[NsmcRawData]):\n",
|
||||
" labels = [s.label for s in batch]\n",
|
||||
" return tokenizer([s.document for s in batch], return_tensors='pt', padding='longest', truncation=True), torch.tensor(labels)\n",
|
||||
" return collate_fn\n",
|
||||
"\n",
|
||||
"collate = make_collate_fn(tokenizer)\n",
|
||||
"print(collate(data[0:2]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1cff8e03",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"간단한 collate function"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "89eb64d8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
41
ndataset.py
Normal file
41
ndataset.py
Normal file
@ -0,0 +1,41 @@
|
||||
import sys
|
||||
from typing import List
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
from ndata import readNsmcRawData, NsmcRawData
|
||||
|
||||
def readNsmcDataAll():
    """Load the NSMC corpus from disk and wrap both splits as datasets.

    Returns: train, test
        Both are NsmcDataset instances built from the raw NSMC tab-separated
        rating files.
    """
    # Progress messages go to stderr so they do not pollute piped stdout.
    print("read train set", file=sys.stderr)
    train_raw = readNsmcRawData(
        "nsmc/nsmc-master/ratings_train.txt", use_tqdm=True, total=150_000
    )
    print("read test set", file=sys.stderr)
    test_raw = readNsmcRawData(
        "nsmc/nsmc-master/ratings_test.txt", use_tqdm=True, total=50_000
    )
    return NsmcDataset(train_raw), NsmcDataset(test_raw)
|
||||
|
||||
class NsmcDataset(Dataset):
    """Minimal map-style Dataset over a list of NSMC samples.

    Stores the given list as-is (no copy); indexing and length delegate
    straight to the underlying list, so slices behave like list slices.
    """

    def __init__(self, data: List[NsmcRawData]):
        # Hold a direct reference to the sample list.
        self.x = data

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # Delegates to list indexing, so integer indices and slices both work.
        return self.x[idx]
|
||||
|
||||
def make_collate_fn(tokenizer: PreTrainedTokenizer):
    """Build a DataLoader collate_fn that tokenizes a batch of NSMC samples.

    Args:
        tokenizer: tokenizer used to encode each sample's ``document``
            (padded to the longest sequence in the batch, truncated to the
            model maximum).

    Returns:
        A ``collate_fn(batch)`` mapping a list of NsmcRawData to a tuple of
        (batch encoding with PyTorch tensors, 1-D label tensor).
    """
    # BUG FIX: the parameter was previously misspelled "tokenzier", so the
    # closure silently captured a *global* `tokenizer` instead of the
    # argument (NameError when no such global existed).
    def collate_fn(batch: List[NsmcRawData]):
        labels = [s.label for s in batch]
        encoded = tokenizer(
            [s.document for s in batch],
            return_tensors='pt',
            padding='longest',
            truncation=True,
        )
        return encoded, torch.tensor(labels)

    return collate_fn
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: load the multilingual BERT tokenizer, read the NSMC
    # training split, and collate the first two samples.
    from transformers import BertTokenizer

    print("load bert tokenizer...")
    PRETRAINED_MODEL_NAME = 'bert-base-multilingual-cased'
    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)

    data = readNsmcRawData(
        "nsmc/nsmc-master/ratings_train.txt", use_tqdm=True, total=150000
    )

    collate = make_collate_fn(tokenizer)
    print(collate(data[0:2]))
|
Loading…
Reference in New Issue
Block a user