From 66727770d845da3676cb56bfbf1949260c2222c4 Mon Sep 17 00:00:00 2001 From: monoid Date: Wed, 23 Feb 2022 20:38:28 +0900 Subject: [PATCH] fix: variable name error --- ndataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndataset.py b/ndataset.py index c06b7ba..a7469c0 100644 --- a/ndataset.py +++ b/ndataset.py @@ -23,7 +23,7 @@ class NsmcDataset(Dataset): def __getitem__(self, idx): return self.x[idx] -def make_collate_fn(tokenzier: PreTrainedTokenizer): +def make_collate_fn(tokenizer: PreTrainedTokenizer): def collate_fn(batch: List[NsmcRawData]): labels = [s.label for s in batch] return tokenizer([s.document for s in batch], return_tensors='pt', padding='longest', truncation=True), torch.tensor(labels)