Compare commits


10 Commits

SHA1 Message Date
46f8d08fd7 feat: training english 2022-02-22 23:36:24 +09:00
30cb6b8fe0 feat: tagIdConverter size property 2022-02-22 23:35:59 +09:00
28ddd289b7 feat: support english 2022-02-22 18:59:07 +09:00
54e757c247 fix: print out stderr 2022-02-22 17:45:53 +09:00
a7e447b6d6 refactor: refactor preprocessing.py 2022-02-22 17:26:11 +09:00
84761d23be refactor: free tagIdConverter 2022-02-22 17:23:14 +09:00
142ad917bc feat: get args 2022-02-22 17:20:16 +09:00
609174b089 feat: add eng_tag 2022-02-22 16:47:16 +09:00
883f39d645 feat(read_data): add english data 2022-02-22 16:33:07 +09:00
bb1e0b5c64 ADD: add line sep 2022-02-22 16:31:56 +09:00
10 changed files with 1826815 additions and 144 deletions

EngTraning.ipynb Normal file (1115 additions)

File diff suppressed because one or more lines are too long


View File

@@ -5,8 +5,6 @@ from read_data import TagIdConverter
 from preprocessing import readPreporcssedDataAll
 from transformers import PreTrainedTokenizer
-tagIdConverter = TagIdConverter()
 class DatasetArray(Dataset):
     def __init__(self, data):
         self.x = data
@@ -39,7 +37,7 @@ def wrap_sentence(tokenizer: PreTrainedTokenizer, sentence)
 def wrap_entities(tagIdConverter: TagIdConverter, entities):
     return [tagIdConverter.O_id] + entities + [tagIdConverter.O_id]
-def make_collate_fn(tokenizer: PreTrainedTokenizer):
+def make_collate_fn(tokenizer: PreTrainedTokenizer, tagIdConverter: TagIdConverter):
     def ret_fn(batch):
         words = [wrap_sentence(tokenizer,item["ids"]) for item in batch]
         entities = [wrap_entities(tagIdConverter,item["entity_ids"]) for item in batch]
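The hunks above drop the module-level tagIdConverter global and instead pass the converter into make_collate_fn. A minimal, self-contained sketch of that dependency-injection pattern; the DummyTokenizer/DummyConverter names are illustrative stand-ins, not the repo's classes:

# Sketch only: stand-in classes illustrate passing the tokenizer and tag
# converter into the collate-fn factory instead of relying on module globals.
from typing import List

class DummyTokenizer:
    cls_token_id = 101
    sep_token_id = 102

class DummyConverter:
    O_id = 9  # id of the "O" tag, mirroring eng_tags.json

def make_collate_fn(tokenizer: DummyTokenizer, converter: DummyConverter):
    def collate(batch: List[dict]):
        words = [[tokenizer.cls_token_id] + item["ids"] + [tokenizer.sep_token_id]
                 for item in batch]
        entities = [[converter.O_id] + item["entity_ids"] + [converter.O_id]
                    for item in batch]
        return words, entities
    return collate

collate = make_collate_fn(DummyTokenizer(), DummyConverter())
print(collate([{"ids": [7, 8], "entity_ids": [1, 5]}]))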

eng_tags.json Normal file (42 additions)

@@ -0,0 +1,42 @@
[
  {
    "name": "[PAD]",
    "index": 0
  },
  {
    "name": "B-LOC",
    "index": 1
  },
  {
    "name": "B-MISC",
    "index": 2
  },
  {
    "name": "B-ORG",
    "index": 3
  },
  {
    "name": "B-PER",
    "index": 4
  },
  {
    "name": "I-LOC",
    "index": 5
  },
  {
    "name": "I-MISC",
    "index": 6
  },
  {
    "name": "I-ORG",
    "index": 7
  },
  {
    "name": "I-PER",
    "index": 8
  },
  {
    "name": "O",
    "index": 9
  }
]
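eng_tags.json follows the same {"name", "index"} layout as the Korean tags file, so it can be loaded into simple name/id lookup tables; presumably this is what TagIdConverter does internally when given a tag path (an assumption, since its constructor is not part of this diff). A small sketch:

import json

# Assumes the working directory contains the eng_tags.json added in this compare.
with open("eng_tags.json", encoding="utf-8") as fp:
    tags = json.load(fp)

token_to_id = {t["name"]: t["index"] for t in tags}   # e.g. {"[PAD]": 0, "B-LOC": 1, ...}
id_to_token = {t["index"]: t["name"] for t in tags}

print(token_to_id["O"])    # 9
print(id_to_token[0])      # [PAD]
print(len(token_to_id))    # 10 labels for the English tag set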

engpre/dev.json Normal file (306118 additions)

File diff suppressed because it is too large

engpre/test.json Normal file (287552 additions)

File diff suppressed because it is too large

engpre/train.json Normal file (1231776 additions)

File diff suppressed because it is too large

View File

@@ -29,6 +29,7 @@ if __name__ == "__main__":
     test_list = [*range(20,-10,-1)]
     for g in groupby_index(test_list,4):
         print([*g])
+    print("===")
     print([*map(lambda x:[*x],groupby_index([1,2,3,4],2))])
     for g in groupby_index([1,2,3,4],2):
         print([*g])
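The demo above suggests groupby_index yields consecutive fixed-size groups from an iterable, so groupby_index([1,2,3,4],2) prints [1, 2] and then [3, 4]. The repo's implementation is not shown in this compare; a plausible equivalent, for reference only:

from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def groupby_index(iterable: Iterable[T], n: int) -> Iterator[List[T]]:
    """Yield consecutive chunks of n items (the last chunk may be shorter)."""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, n))
        if not chunk:
            return
        yield chunk

for g in groupby_index([1, 2, 3, 4], 2):
    print([*g])   # [1, 2] then [3, 4]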

View File

@@ -1,16 +1,17 @@
-from read_data import TagIdConverter, make_long_namedEntity, readKoreanDataAll, Sentence
-from typing import Any, NamedTuple, List, Sequence, TypeVar
+import argparse
+import os
+import sys
+from read_data import TagIdConverter, make_long_namedEntity, readEnglishDataAll, readKoreanDataAll, Sentence
+from typing import Any, List
 import json
 import os.path as path
 import tqdm
 from transformers import PreTrainedTokenizer
 
-PREPROCESSING_BASE_PATH = 'prepro'
-converter = TagIdConverter()
+PRE_BASE_PATH = 'prepro'
 
-def preprocessing(tokenizer : PreTrainedTokenizer,dataset: List[Sentence]):
+def preprocessing(tokenizer : PreTrainedTokenizer, converter :TagIdConverter,dataset: List[Sentence]):
     ret = []
     for item in tqdm.tqdm(dataset):
         assert len(item.word) == len(item.detail)
@@ -35,26 +36,44 @@ def readPreprocessedData(path: str):
     with open(path,"r", encoding="utf-8") as fp:
         return json.load(fp)
 
-def readPreporcssedDataAll():
-    train = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"train.json"))
-    dev = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"dev.json"))
-    test = readPreprocessedData(path.join(PREPROCESSING_BASE_PATH,"test.json"))
+def readPreporcssedDataAll(path = PRE_BASE_PATH):
+    train = readPreprocessedData(os.path.join(path,"train.json"))
+    dev = readPreprocessedData(os.path.join(path,"dev.json"))
+    test = readPreprocessedData(os.path.join(path,"test.json"))
     return train, dev, test
 
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--kind", default="korean")
+    parser.add_argument("path",default=PRE_BASE_PATH,help="directory path of processed data")
+    parser.add_argument("--tag", default="tags.json",help="path of tag description")
+    args = parser.parse_args()
+    dirPath = args.path
+    if args.kind == "korean":
+        rawTrain, rawDev, rawTest = readKoreanDataAll()
+    elif args.kind == "english":
+        rawTrain, rawDev, rawTest = readEnglishDataAll()
+    else:
+        print("unknown language",file=sys.stderr)
+        exit(1)
+    converter = TagIdConverter(args.tag)
+    os.makedirs(dirPath)
     from transformers import BertTokenizer
     PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'
-    rawTrain, rawDev, rawTest = readKoreanDataAll()
-    print("load tokenzier...")
+    print("load tokenzier...",file=sys.stderr)
     tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)
-    print("process train...")
-    train = preprocessing(tokenizer,rawTrain)
-    saveObject(path.join(PREPROCESSING_BASE_PATH,"train.json"),train)
-    print("process dev...")
-    dev = preprocessing(tokenizer,rawDev)
-    saveObject(path.join(PREPROCESSING_BASE_PATH,"dev.json"),dev)
-    print("process test...")
-    test = preprocessing(tokenizer,rawTest)
-    saveObject(path.join(PREPROCESSING_BASE_PATH,"test.json"),test)
+    print("process train...",file=sys.stderr)
+    train = preprocessing(tokenizer,converter,rawTrain)
+    saveObject(path.join(dirPath,"train.json"),train)
+    print("process dev...",file=sys.stderr)
+    dev = preprocessing(tokenizer,converter,rawDev)
+    saveObject(path.join(dirPath,"dev.json"),dev)
+    print("process test...",file=sys.stderr)
+    test = preprocessing(tokenizer,converter,rawTest)
+    saveObject(path.join(dirPath,"test.json"),test)
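With this change the preprocessed output directory is a command-line argument rather than the hard-coded 'prepro' constant, which is presumably how the new engpre/ files in this compare were produced. A hedged usage sketch; the directory names follow the files listed above, the exact command line is not shown here:

# Loading the preprocessed splits from a chosen directory instead of the old
# module-level PREPROCESSING_BASE_PATH constant. "engpre" is the English data
# directory added in this compare; the default still points at "prepro".
from preprocessing import readPreporcssedDataAll

en_train, en_dev, en_test = readPreporcssedDataAll("engpre")
ko_train, ko_dev, ko_test = readPreporcssedDataAll()  # default: "prepro"
print(len(en_train), len(en_dev), len(en_test))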

View File

@@ -1,6 +1,9 @@
 import enum
-from typing import NamedTuple, List, Sequence, TypeVar
+from io import TextIOWrapper
+import sys
+from typing import Iterable, NamedTuple, List, Sequence, TypeVar
 import json
+import argparse
 
 KoreanBase="[Ko, En] NER, POStag data/국문 NER, POS"
+EnglishBase="[Ko, En] NER, POStag data/영문 NER, POS"
@@ -25,29 +28,20 @@ class Sentence(NamedTuple):
         self.namedEntity.append(namedEntity)
         self.detail.append(detail)
 
 T = TypeVar('T')
-def readDataList(lst: List[T]):
-    ret = []
+def readDataList(lst: Iterable[str], sep="\t"):
+    ret:List[str] = []
     for l in lst:
-        if len(l) > 0:
-            ret.append(l)
-        else:
+        l = l.strip()
+        if l == "":
             yield ret
             ret.clear()
+        else:
+            ret.append(l.split(sep))
 
-def readKoreanData(path: str) -> List[Sentence]:
-    fp = open(path,encoding="utf-8")
-    data = []
-    for line in fp.readlines():
-        line = line.strip()
-        if line == "":
-            data.append([])
-        else:
-            data.append(line.split("\t"))
-    fp.close()
-    # Do not use csv reader.
-    ret = []
-    for lines in readDataList(data):
+def readKoreanData(fp: TextIOWrapper) -> List[Sentence]:
+    ret = []
+    # NOTE(monoid): Do not use csv reader.
+    for lines in readDataList(fp):
         sentence = Sentence([],[],[],[])
         for line in lines:
             word_pos:str = line[0]
@@ -55,18 +49,41 @@ def readKoreanData(path: str) -> List[Sentence]:
             sentence.append(words[0],line[1],line[2],line[3])
         ret.append(sentence)
-    fp.close()
     return ret
 
+def readEnglishData(fp: TextIOWrapper) -> List[Sentence]:
+    ret = []
+    for lines in readDataList(fp,sep=" "):
+        if len(lines) == 1 and lines[0][0] == "-DOCSTART-":
+            continue
+        sentence = Sentence([],[],[],[])
+        for line in lines:
+            sentence.append(line[0],line[1],line[2],line[3])
+        ret.append(sentence)
+    return ret
 
 def readKoreanDataAll():
     """
-    @return train, dev, test tuple
     Each entry is structured as follows:
     POS,
+    Return: train, dev, test tuple
     """
-    dev = readKoreanData(f"{KoreanBase}/dev.txt")
-    test = readKoreanData(f"{KoreanBase}/test.txt")
-    train = readKoreanData(f"{KoreanBase}/train.txt")
+    with open(f"{KoreanBase}/dev.txt", encoding="utf-8") as fp:
+        dev = readKoreanData(fp)
+    with open(f"{KoreanBase}/test.txt", encoding="utf-8") as fp:
+        test = readKoreanData(fp)
+    with open(f"{KoreanBase}/train.txt", encoding="utf-8") as fp:
+        train = readKoreanData(fp)
     return train, dev, test
 
+def readEnglishDataAll():
+    with open(f"{EnglishBase}/valid.txt", encoding="utf-8") as fp:
+        dev = readEnglishData(fp)
+    with open(f"{EnglishBase}/test.txt", encoding="utf-8") as fp:
+        test = readEnglishData(fp)
+    with open(f"{EnglishBase}/train.txt", encoding="utf-8") as fp:
+        train = readEnglishData(fp)
+    return train, dev, test
+
 class TagIdConverter:
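readEnglishData reads the space-separated files under EnglishBase, which appear to follow the CoNLL-2003 layout: token, POS, chunk, NER tag per line, blank lines between sentences, and -DOCSTART- header lines skipped. A self-contained sketch of that format and the blank-line grouping, using a tiny inline sample rather than the repo's data files:

# Sample in the layout readEnglishData appears to expect.
sample = """-DOCSTART- -X- -X- O

EU NNP B-NP B-ORG
rejects VBZ B-VP O
German JJ B-NP B-MISC

Peter NNP B-NP B-PER
"""

sentences = []
current = []
for raw in sample.splitlines():
    line = raw.strip()
    if line == "":
        if current:
            sentences.append(current)
            current = []
        continue
    cols = line.split(" ")
    if cols[0] == "-DOCSTART-":
        continue
    current.append(cols)
if current:
    sentences.append(current)

print(len(sentences))    # 2 sentences
print(sentences[0][0])   # ['EU', 'NNP', 'B-NP', 'B-ORG']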
@@ -86,6 +103,12 @@ class TagIdConverter:
     @property
     def pad_id(self):
         return self.vocab["[PAD]"]
+    @property
+    def size(self):
+        return len(self.vocab)
+
+    def __len__(self):
+        return self.size
     def convert_ids_to_tokens(self,ids: List[int]):
         return [self.ids_to_token[id] for id in ids]
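The new size property and __len__ expose the number of tags, which is useful wherever the label count is needed, for example when sizing a token-classification head. The sketch below is illustrative only; the eng_tags.json path argument and the Hugging Face call are assumptions, not something this compare shows being done:

from read_data import TagIdConverter
from transformers import BertForTokenClassification

# Assumes TagIdConverter accepts a tag-description path, as preprocessing.py
# does with TagIdConverter(args.tag).
tagIdConverter = TagIdConverter("eng_tags.json")
model = BertForTokenClassification.from_pretrained(
    "bert-base-multilingual-cased",
    num_labels=len(tagIdConverter),   # 10 for the English tag set
)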
@@ -151,29 +174,37 @@ def make_long_namedEntity(a,b,c):
         break
     return ret
 
 """
 extracts and stores tags set from the given data.
 """
 if __name__ == "__main__":
-    from tqdm import tqdm
-    t = TagIdConverter()
+    parser = argparse.ArgumentParser(description="create tags list")
+    parser.add_argument("--kind","-k",default='korean', help='kind of language: korean or english')
+    parser.add_argument("--stdout",action='store_true',help='print tags data to stdout')
+    parser.add_argument("--path",default="tags.json", help="path of tags data")
+    args = parser.parse_args()
+    from tqdm import tqdm
+    if args.kind == "korean" or args.kind == "ko" or args.kind == "kor":
+        train, dev, test = readEnglishDataAll()
+    elif args.kind == "english" or args.kind == "en" or args.kind =="eng":
+        train, dev, test = readKoreanDataAll()
+    else:
+        print("unknown language",file=sys.stderr)
+        exit(1)
 
     vocab = set()
     def getTags(lst: List[Sentence]):
         for s in tqdm(lst):
             for e in s.detail:
                 vocab.add(e)
 
-    print("get tags from train...")
+    print("get tags from train...",file=sys.stderr)
     getTags(train)
-    print("get tags from dev...")
+    print("get tags from dev...",file=sys.stderr)
     getTags(dev)
-    print("get tags from test...")
+    print("get tags from test...",file=sys.stderr)
     getTags(test)
 
-    print(vocab)
+    print(vocab,file=sys.stderr)
     for v in vocab:
         if v == "O":
             continue
@@ -183,7 +214,6 @@ if __name__ == "__main__":
         if not v in vocab:
             print("could not found pair " ,v)
             vocab.add(v)
     tags = [{"name":"[PAD]","index":0}]
     i = 1
     vocab_list = [*vocab]
@@ -191,6 +221,10 @@
     for v in vocab_list:
         tags.append({"name":v,"index":i})
         i += 1
-    print(tags)
-    with open("tags.json","w",encoding="utf-8") as fp:
+    print(tags,file=sys.stderr)
+    if args.stdout:
+        json.dump(tags,sys.stdout,ensure_ascii=False, indent=2)
+    else:
+        p = args.path
+        with open(p,"w",encoding="utf-8") as fp:
+            json.dump(tags,fp,ensure_ascii=False, indent=2)
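Between the hunks above, the script completes B-/I- tag pairs before numbering them and now writes the result either to stdout or to the --path file. A compact, self-contained sketch of that flow; the pairing step between the shown hunks is elided in this compare, so the reconstruction below is an assumption:

import json
import sys

# Tags collected from the data (illustrative subset).
vocab = {"B-LOC", "I-LOC", "B-PER", "O"}

# Ensure every B-XXX has an I-XXX partner and vice versa (assumed behaviour of
# the elided loop that prints "could not found pair").
for v in sorted(vocab):
    if v == "O":
        continue
    prefix, kind = v.split("-", 1)
    pair = ("I-" if prefix == "B" else "B-") + kind
    if pair not in vocab:
        print("could not found pair", v, file=sys.stderr)
        vocab.add(pair)

tags = [{"name": "[PAD]", "index": 0}]
for i, name in enumerate(sorted(vocab), start=1):
    tags.append({"name": name, "index": i})

# Same output choice as the new __main__ block: stdout or a file path.
json.dump(tags, sys.stdout, ensure_ascii=False, indent=2)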