fix: variable name error
This commit is contained in:
parent
a1f4605d8b
commit
66727770d8
@ -23,7 +23,7 @@ class NsmcDataset(Dataset):
|
||||
def __getitem__(self, idx):
|
||||
return self.x[idx]
|
||||
|
||||
def make_collate_fn(tokenzier: PreTrainedTokenizer):
|
||||
def make_collate_fn(tokenizer: PreTrainedTokenizer):
|
||||
def collate_fn(batch: List[NsmcRawData]):
|
||||
labels = [s.label for s in batch]
|
||||
return tokenizer([s.document for s in batch], return_tensors='pt', padding='longest', truncation=True), torch.tensor(labels)
|
||||
|
Loading…
Reference in New Issue
Block a user