From 54024fa7fd6e456ac92a4f42b668e168ff54562f Mon Sep 17 00:00:00 2001 From: monoid Date: Wed, 23 Feb 2022 19:22:30 +0900 Subject: [PATCH] feat: raw nsmc data reader --- ndata.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ ndata.test.py | 17 +++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 ndata.test.py diff --git a/ndata.py b/ndata.py index 5be1fb2..b5afda2 100644 --- a/ndata.py +++ b/ndata.py @@ -1,4 +1,49 @@ +from io import TextIOWrapper +from typing import List, Union import os +import csv +from dataclasses import dataclass +import tqdm +@dataclass +class NsmcRawData: + id: int + document: str + label: int + +class NsmcRawDataReader: + def __init__(self, file: Union[str, TextIOWrapper]): + self.fp = file + self.need_close = isinstance(file,str) + if self.need_close: + self.fp = open(file,"r",encoding="utf-8",newline='\n') + self.rd = csv.DictReader(self.fp,delimiter='\t') + + def __iter__(self): + mapper = lambda data: NsmcRawData(int(data["id"]),data["document"],int(data["label"])) + return iter(map(mapper,self.rd)) + + def close(self): + if self.need_close: + self.fp.close() + + def __enter__(self): + return self + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + +def readNsmcRawData(file: Union[str, TextIOWrapper], use_tqdm = False, total: int = 0) -> List[NsmcRawData]: + dataset = [] + with NsmcRawDataReader(file) as dataReader: + if use_tqdm and total > 0: + for d in tqdm.tqdm(dataReader, total=total): + dataset.append(d) + else: + for data in dataReader: + dataset.append(data) + return dataset BASE_PATH = "nsmc/nsmc-master" +if __name__ == "__main__": + dataset = [] + raw = readNsmcRawData(f"{BASE_PATH}/ratings.txt", use_tqdm= True, total = 200000) \ No newline at end of file diff --git a/ndata.test.py b/ndata.test.py new file mode 100644 index 0000000..daf876d --- /dev/null +++ b/ndata.test.py @@ -0,0 +1,17 @@ +import unittest +from ndata import * +import io + +class Testing(unittest.TestCase): + def testcase(self): + text = """id\tdocument\tlabel +20\t사랑해요\t1""" + textfile = io.StringIO(text) + datas = readNsmcRawData(textfile) + i = datas[0] + self.assertEqual(i.id,20) + self.assertEqual(i.document,"사랑해요") + self.assertEqual(i.label,1) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file