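refactor: pass TagIdConverter into preprocessing() as a parameter and rename PREPROCESSING_BASE_PATH to PRE_BASE_PATH in preprocessing.py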

This commit is contained in:
monoid 2022-02-22 17:26:11 +09:00
parent 84761d23be
commit a7e447b6d6

View File

@ -7,10 +7,9 @@ import os.path as path
import tqdm
from transformers import PreTrainedTokenizer
PREPROCESSING_BASE_PATH = 'prepro'
converter = TagIdConverter()
PRE_BASE_PATH = 'prepro'
def preprocessing(tokenizer : PreTrainedTokenizer,dataset: List[Sentence]):
def preprocessing(tokenizer : PreTrainedTokenizer, converter :TagIdConverter,dataset: List[Sentence]):
ret = []
for item in tqdm.tqdm(dataset):
assert len(item.word) == len(item.detail)
@ -36,9 +35,9 @@ def readPreprocessedData(path: str):
return json.load(fp)
def readPreporcssedDataAll():
train = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"train.json"))
dev = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"dev.json"))
test = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"test.json"))
train = readPreprocessedData(path.join(PRE_BASE_PATH,"train.json"))
dev = readPreprocessedData(path.join(PRE_BASE_PATH,"dev.json"))
test = readPreprocessedData(path.join(PRE_BASE_PATH,"test.json"))
return train, dev, test
if __name__ == "__main__":
@ -48,13 +47,14 @@ if __name__ == "__main__":
rawTrain, rawDev, rawTest = readKoreanDataAll()
print("load tokenzier...")
tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)
converter = TagIdConverter()
print("process train...")
train = preprocessing(tokenizer,rawTrain)
saveObject(path.join(PREPROCESSING_BASE_PATH,"train.json"),train)
train = preprocessing(tokenizer,converter,rawTrain)
saveObject(path.join(PRE_BASE_PATH,"train.json"),train)
print("process dev...")
dev = preprocessing(tokenizer,rawDev)
saveObject(path.join(PREPROCESSING_BASE_PATH,"dev.json"),dev)
dev = preprocessing(tokenizer,converter,rawDev)
saveObject(path.join(PRE_BASE_PATH,"dev.json"),dev)
print("process test...")
test = preprocessing(tokenizer,rawTest)
saveObject(path.join(PREPROCESSING_BASE_PATH,"test.json"),test)
test = preprocessing(tokenizer,converter,rawTest)
saveObject(path.join(PRE_BASE_PATH,"test.json"),test)