nsmc-study/ndataset.py

import sys
from typing import List
from torch.utils.data import Dataset
import torch
from transformers import PreTrainedTokenizer
from ndata import readNsmcRawData, NsmcRawData


def readNsmcDataAll():
    """Read the NSMC corpus and return (train, dev, test) datasets.

    NSMC ships only a 150k train file and a 50k test file, so the test
    file is split here into 30k test and 20k dev examples.
    """
    print("read train set", file=sys.stderr)
    train = readNsmcRawData("nsmc/nsmc-master/ratings_train.txt", use_tqdm=True, total=150_000)
    print("read test set", file=sys.stderr)
    testBig = readNsmcRawData("nsmc/nsmc-master/ratings_test.txt", use_tqdm=True, total=50_000)
    test = testBig[:30_000]
    dev = testBig[30_000:]
    return NsmcDataset(train), NsmcDataset(dev), NsmcDataset(test)


class NsmcDataset(Dataset):
    """Thin torch Dataset wrapper over a list of NsmcRawData records."""

    def __init__(self, data: List[NsmcRawData]):
        self.x = data

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx]


def make_collate_fn(tokenizer: PreTrainedTokenizer):
    def collate_fn(batch: List[NsmcRawData]):
        labels = [s.label for s in batch]
        # Tokenize per batch: pad only to the longest sequence in the batch,
        # truncate to the tokenizer's model maximum length.
        encodings = tokenizer([s.document for s in batch], return_tensors='pt',
                              padding='longest', truncation=True)
        return encodings, torch.tensor(labels)
    return collate_fn


if __name__ == "__main__":
    from transformers import BertTokenizer

    print("load bert tokenizer...")
    PRETRAINED_MODEL_NAME = 'bert-base-multilingual-cased'
    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
    data = readNsmcRawData("nsmc/nsmc-master/ratings_train.txt", use_tqdm=True, total=150_000)
    collate = make_collate_fn(tokenizer)
    print(collate(data[0:2]))
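
    # A minimal usage sketch: wrap the raw records in NsmcDataset and hand the
    # collate_fn to a torch DataLoader. batch_size=32 is an arbitrary
    # illustrative choice, not something this module prescribes.
    from torch.utils.data import DataLoader
    loader = DataLoader(NsmcDataset(data), batch_size=32, shuffle=True, collate_fn=collate)
    encodings, labels = next(iter(loader))
    print(encodings['input_ids'].shape, labels.shape)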