refactor: free tagIdConverter
This commit is contained in:
parent
142ad917bc
commit
84761d23be
179
Training.ipynb
179
Training.ipynb
File diff suppressed because one or more lines are too long
@ -5,8 +5,6 @@ from read_data import TagIdConverter
|
|||||||
from preprocessing import readPreporcssedDataAll
|
from preprocessing import readPreporcssedDataAll
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
tagIdConverter = TagIdConverter()
|
|
||||||
|
|
||||||
class DatasetArray(Dataset):
|
class DatasetArray(Dataset):
|
||||||
def __init__(self, data):
|
def __init__(self, data):
|
||||||
self.x = data
|
self.x = data
|
||||||
@ -39,7 +37,7 @@ def wrap_sentence(tokenizer: PreTrainedTokenizer, sentence):
|
|||||||
def wrap_entities(tagIdConverter: TagIdConverter, entities):
|
def wrap_entities(tagIdConverter: TagIdConverter, entities):
|
||||||
return [tagIdConverter.O_id] + entities + [tagIdConverter.O_id]
|
return [tagIdConverter.O_id] + entities + [tagIdConverter.O_id]
|
||||||
|
|
||||||
def make_collate_fn(tokenizer: PreTrainedTokenizer):
|
def make_collate_fn(tokenizer: PreTrainedTokenizer, tagIdConverter: TagIdConverter):
|
||||||
def ret_fn(batch):
|
def ret_fn(batch):
|
||||||
words = [wrap_sentence(tokenizer,item["ids"]) for item in batch]
|
words = [wrap_sentence(tokenizer,item["ids"]) for item in batch]
|
||||||
entities = [wrap_entities(tagIdConverter,item["entity_ids"]) for item in batch]
|
entities = [wrap_entities(tagIdConverter,item["entity_ids"]) for item in batch]
|
||||||
|
Loading…
Reference in New Issue
Block a user