Compare commits

...

3 Commits

Author SHA1 Message Date
d2256b0ee9 feat: use minibatch 2022-02-18 17:32:13 +09:00
8abd1fe418 feat: add groupby_index 2022-02-18 17:29:46 +09:00
2877525eee test(dataset): change assertEquals to assertEqual 2022-02-18 17:28:31 +09:00
4 changed files with 256 additions and 286 deletions

File diff suppressed because one or more lines are too long

View File

@ -3,7 +3,7 @@ import unittest
class Test(unittest.TestCase): class Test(unittest.TestCase):
def test_padding_array(self): def test_padding_array(self):
self.assertEquals(padding_array([[1,2],[3]]),[[1,2],[3,0]]) self.assertEqual(padding_array([[1,2],[3]]),[[1,2],[3,0]])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

35
groupby_index.py Normal file
View File

@ -0,0 +1,35 @@
import itertools
from typing import Tuple, TypeVar, Iterable
T = TypeVar('T')
def groupby_index(iter: Iterable[T],n:int) -> Iterable[Iterable[T]]:
"""group list by index
Args:
iter (Iterable[T]): iterator to group by index
n (int): The size of groups
Returns:
Iterable[Iterable[T]]: iterable object to group by index
>>> [*map(lambda x:[*x],groupby_index([1,2,3,4],2))]
[[1, 2], [3, 4]]
"""
def keyfunc(x: Tuple[int,T]) -> int:
k, _ = x
return (k // n)
def mapper(x: Tuple[int, Tuple[int, T]]):
_, v = x
return map(lambda y: y[1],v)
g = itertools.groupby(enumerate(iter), keyfunc)
return map(mapper,g)
if __name__ == "__main__":
test_list = [*range(20,-10,-1)]
for g in groupby_index(test_list,4):
print([*g])
print([*map(lambda x:[*x],groupby_index([1,2,3,4],2))])
for g in groupby_index([1,2,3,4],2):
print([*g])

9
groupby_index.test.py Normal file
View File

@ -0,0 +1,9 @@
import unittest
from groupby_index import *
class Test(unittest.TestCase):
def test_padding_array(self):
self.assertEqual([*map(lambda x:[*x],groupby_index([1,2,3,4],2))],[[1,2],[3,4]])
if __name__ == '__main__':
unittest.main()