feat: raw nsmc data reader
This commit is contained in:
parent
406ff30715
commit
54024fa7fd
45
ndata.py
45
ndata.py
@ -1,4 +1,49 @@
|
||||
from io import TextIOWrapper
|
||||
from typing import List, Union
|
||||
import os
|
||||
import csv
|
||||
from dataclasses import dataclass
|
||||
import tqdm
|
||||
@dataclass
|
||||
class NsmcRawData:
|
||||
id: int
|
||||
document: str
|
||||
label: int
|
||||
|
||||
class NsmcRawDataReader:
|
||||
def __init__(self, file: Union[str, TextIOWrapper]):
|
||||
self.fp = file
|
||||
self.need_close = isinstance(file,str)
|
||||
if self.need_close:
|
||||
self.fp = open(file,"r",encoding="utf-8",newline='\n')
|
||||
self.rd = csv.DictReader(self.fp,delimiter='\t')
|
||||
|
||||
def __iter__(self):
|
||||
mapper = lambda data: NsmcRawData(int(data["id"]),data["document"],int(data["label"]))
|
||||
return iter(map(mapper,self.rd))
|
||||
|
||||
def close(self):
|
||||
if self.need_close:
|
||||
self.fp.close()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def readNsmcRawData(file: Union[str, TextIOWrapper], use_tqdm = False, total: int = 0) -> List[NsmcRawData]:
|
||||
dataset = []
|
||||
with NsmcRawDataReader(file) as dataReader:
|
||||
if use_tqdm and total > 0:
|
||||
for d in tqdm.tqdm(dataReader, total=total):
|
||||
dataset.append(d)
|
||||
else:
|
||||
for data in dataReader:
|
||||
dataset.append(data)
|
||||
return dataset
|
||||
|
||||
BASE_PATH = "nsmc/nsmc-master"
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = []
|
||||
raw = readNsmcRawData(f"{BASE_PATH}/ratings.txt", use_tqdm= True, total = 200000)
|
17
ndata.test.py
Normal file
17
ndata.test.py
Normal file
@ -0,0 +1,17 @@
|
||||
import unittest
|
||||
from ndata import *
|
||||
import io
|
||||
|
||||
class Testing(unittest.TestCase):
|
||||
def testcase(self):
|
||||
text = """id\tdocument\tlabel
|
||||
20\t사랑해요\t1"""
|
||||
textfile = io.StringIO(text)
|
||||
datas = readNsmcRawData(textfile)
|
||||
i = datas[0]
|
||||
self.assertEqual(i.id,20)
|
||||
self.assertEqual(i.document,"사랑해요")
|
||||
self.assertEqual(i.label,1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user