refactor: refactor preprocessing.py

This commit is contained in:
monoid 2022-02-22 17:26:11 +09:00
parent 84761d23be
commit a7e447b6d6

View File

@ -7,10 +7,9 @@ import os.path as path
import tqdm import tqdm
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
PREPROCESSING_BASE_PATH = 'prepro' PRE_BASE_PATH = 'prepro'
converter = TagIdConverter()
def preprocessing(tokenizer : PreTrainedTokenizer,dataset: List[Sentence]): def preprocessing(tokenizer : PreTrainedTokenizer, converter :TagIdConverter,dataset: List[Sentence]):
ret = [] ret = []
for item in tqdm.tqdm(dataset): for item in tqdm.tqdm(dataset):
assert len(item.word) == len(item.detail) assert len(item.word) == len(item.detail)
@ -36,9 +35,9 @@ def readPreprocessedData(path: str):
return json.load(fp) return json.load(fp)
def readPreporcssedDataAll(): def readPreporcssedDataAll():
train = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"train.json")) train = readPreprocessedData(path.join(PRE_BASE_PATH,"train.json"))
dev = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"dev.json")) dev = readPreprocessedData(path.join(PRE_BASE_PATH,"dev.json"))
test = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"test.json")) test = readPreprocessedData(path.join(PRE_BASE_PATH,"test.json"))
return train, dev, test return train, dev, test
if __name__ == "__main__": if __name__ == "__main__":
@ -48,13 +47,14 @@ if __name__ == "__main__":
rawTrain, rawDev, rawTest = readKoreanDataAll() rawTrain, rawDev, rawTest = readKoreanDataAll()
print("load tokenzier...") print("load tokenzier...")
tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME) tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)
converter = TagIdConverter()
print("process train...") print("process train...")
train = preprocessing(tokenizer,rawTrain) train = preprocessing(tokenizer,converter,rawTrain)
saveObject(path.join(PREPROCESSING_BASE_PATH,"train.json"),train) saveObject(path.join(PRE_BASE_PATH,"train.json"),train)
print("process dev...") print("process dev...")
dev = preprocessing(tokenizer,rawDev) dev = preprocessing(tokenizer,converter,rawDev)
saveObject(path.join(PREPROCESSING_BASE_PATH,"dev.json"),dev) saveObject(path.join(PRE_BASE_PATH,"dev.json"),dev)
print("process test...") print("process test...")
test = preprocessing(tokenizer,rawTest) test = preprocessing(tokenizer,converter,rawTest)
saveObject(path.join(PREPROCESSING_BASE_PATH,"test.json"),test) saveObject(path.join(PRE_BASE_PATH,"test.json"),test)