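refactor(preprocessing.py): pass TagIdConverter to preprocessing() explicitly instead of using a module-level instance; rename PREPROCESSING_BASE_PATH to PRE_BASE_PATH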
This commit is contained in:
parent
84761d23be
commit
a7e447b6d6
@ -7,10 +7,9 @@ import os.path as path
|
|||||||
import tqdm
|
import tqdm
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
PREPROCESSING_BASE_PATH = 'prepro'
|
PRE_BASE_PATH = 'prepro'
|
||||||
converter = TagIdConverter()
|
|
||||||
|
|
||||||
def preprocessing(tokenizer : PreTrainedTokenizer,dataset: List[Sentence]):
|
def preprocessing(tokenizer : PreTrainedTokenizer, converter :TagIdConverter,dataset: List[Sentence]):
|
||||||
ret = []
|
ret = []
|
||||||
for item in tqdm.tqdm(dataset):
|
for item in tqdm.tqdm(dataset):
|
||||||
assert len(item.word) == len(item.detail)
|
assert len(item.word) == len(item.detail)
|
||||||
@ -36,9 +35,9 @@ def readPreprocessedData(path: str):
|
|||||||
return json.load(fp)
|
return json.load(fp)
|
||||||
|
|
||||||
def readPreporcssedDataAll():
|
def readPreporcssedDataAll():
|
||||||
train = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"train.json"))
|
train = readPreprocessedData(path.join(PRE_BASE_PATH,"train.json"))
|
||||||
dev = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"dev.json"))
|
dev = readPreprocessedData(path.join(PRE_BASE_PATH,"dev.json"))
|
||||||
test = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"test.json"))
|
test = readPreprocessedData(path.join(PRE_BASE_PATH,"test.json"))
|
||||||
return train, dev, test
|
return train, dev, test
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -48,13 +47,14 @@ if __name__ == "__main__":
|
|||||||
rawTrain, rawDev, rawTest = readKoreanDataAll()
|
rawTrain, rawDev, rawTest = readKoreanDataAll()
|
||||||
print("load tokenzier...")
|
print("load tokenzier...")
|
||||||
tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)
|
tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)
|
||||||
|
converter = TagIdConverter()
|
||||||
|
|
||||||
print("process train...")
|
print("process train...")
|
||||||
train = preprocessing(tokenizer,rawTrain)
|
train = preprocessing(tokenizer,converter,rawTrain)
|
||||||
saveObject(path.join(PREPROCESSING_BASE_PATH,"train.json"),train)
|
saveObject(path.join(PRE_BASE_PATH,"train.json"),train)
|
||||||
print("process dev...")
|
print("process dev...")
|
||||||
dev = preprocessing(tokenizer,rawDev)
|
dev = preprocessing(tokenizer,converter,rawDev)
|
||||||
saveObject(path.join(PREPROCESSING_BASE_PATH,"dev.json"),dev)
|
saveObject(path.join(PRE_BASE_PATH,"dev.json"),dev)
|
||||||
print("process test...")
|
print("process test...")
|
||||||
test = preprocessing(tokenizer,rawTest)
|
test = preprocessing(tokenizer,converter,rawTest)
|
||||||
saveObject(path.join(PREPROCESSING_BASE_PATH,"test.json"),test)
|
saveObject(path.join(PRE_BASE_PATH,"test.json"),test)
|
||||||
|
Loading…
Reference in New Issue
Block a user