refactor: free tagIdConverter
This commit is contained in:
parent
142ad917bc
commit
84761d23be
177
Training.ipynb
177
Training.ipynb
File diff suppressed because one or more lines are too long
@ -5,8 +5,6 @@ from read_data import TagIdConverter
|
||||
from preprocessing import readPreporcssedDataAll
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
tagIdConverter = TagIdConverter()
|
||||
|
||||
class DatasetArray(Dataset):
|
||||
def __init__(self, data):
|
||||
self.x = data
|
||||
@ -39,7 +37,7 @@ def wrap_sentence(tokenizer: PreTrainedTokenizer, sentence):
|
||||
def wrap_entities(tagIdConverter: TagIdConverter, entities):
|
||||
return [tagIdConverter.O_id] + entities + [tagIdConverter.O_id]
|
||||
|
||||
def make_collate_fn(tokenizer: PreTrainedTokenizer):
|
||||
def make_collate_fn(tokenizer: PreTrainedTokenizer, tagIdConverter: TagIdConverter):
|
||||
def ret_fn(batch):
|
||||
words = [wrap_sentence(tokenizer,item["ids"]) for item in batch]
|
||||
entities = [wrap_entities(tagIdConverter,item["entity_ids"]) for item in batch]
|
||||
|
Loading…
Reference in New Issue
Block a user