refactor: refactor preprocessing.py
This commit is contained in:
parent
84761d23be
commit
a7e447b6d6
@ -7,10 +7,9 @@ import os.path as path
|
||||
import tqdm
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
PREPROCESSING_BASE_PATH = 'prepro'
|
||||
converter = TagIdConverter()
|
||||
PRE_BASE_PATH = 'prepro'
|
||||
|
||||
def preprocessing(tokenizer : PreTrainedTokenizer,dataset: List[Sentence]):
|
||||
def preprocessing(tokenizer : PreTrainedTokenizer, converter :TagIdConverter,dataset: List[Sentence]):
|
||||
ret = []
|
||||
for item in tqdm.tqdm(dataset):
|
||||
assert len(item.word) == len(item.detail)
|
||||
@ -36,9 +35,9 @@ def readPreprocessedData(path: str):
|
||||
return json.load(fp)
|
||||
|
||||
def readPreporcssedDataAll():
|
||||
train = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"train.json"))
|
||||
dev = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"dev.json"))
|
||||
test = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"test.json"))
|
||||
train = readPreprocessedData(path.join(PRE_BASE_PATH,"train.json"))
|
||||
dev = readPreprocessedData(path.join(PRE_BASE_PATH,"dev.json"))
|
||||
test = readPreprocessedData(path.join(PRE_BASE_PATH,"test.json"))
|
||||
return train, dev, test
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -48,13 +47,14 @@ if __name__ == "__main__":
|
||||
rawTrain, rawDev, rawTest = readKoreanDataAll()
|
||||
print("load tokenzier...")
|
||||
tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)
|
||||
converter = TagIdConverter()
|
||||
|
||||
print("process train...")
|
||||
train = preprocessing(tokenizer,rawTrain)
|
||||
saveObject(path.join(PREPROCESSING_BASE_PATH,"train.json"),train)
|
||||
train = preprocessing(tokenizer,converter,rawTrain)
|
||||
saveObject(path.join(PRE_BASE_PATH,"train.json"),train)
|
||||
print("process dev...")
|
||||
dev = preprocessing(tokenizer,rawDev)
|
||||
saveObject(path.join(PREPROCESSING_BASE_PATH,"dev.json"),dev)
|
||||
dev = preprocessing(tokenizer,converter,rawDev)
|
||||
saveObject(path.join(PRE_BASE_PATH,"dev.json"),dev)
|
||||
print("process test...")
|
||||
test = preprocessing(tokenizer,rawTest)
|
||||
saveObject(path.join(PREPROCESSING_BASE_PATH,"test.json"),test)
|
||||
test = preprocessing(tokenizer,converter,rawTest)
|
||||
saveObject(path.join(PRE_BASE_PATH,"test.json"),test)
|
||||
|
Loading…
Reference in New Issue
Block a user