feat: train, dev, test
This commit is contained in:
parent
9fcd0786b1
commit
8a1442995b
14
Batch.ipynb
14
Batch.ipynb
@ -3,7 +3,7 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 1,
|
||||||
"id": "c916dd3b",
|
"id": "5a4a1e30",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -25,7 +25,7 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 2,
|
||||||
"id": "d5861234",
|
"id": "710cd5b2",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -39,7 +39,7 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 3,
|
||||||
"id": "5accd3a9",
|
"id": "da018ffe",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -68,7 +68,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "d10fcb83",
|
"id": "69f05cf6",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"data를 준비"
|
"data를 준비"
|
||||||
@ -77,7 +77,7 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 7,
|
||||||
"id": "552fe555",
|
"id": "961edd10",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -114,7 +114,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "1cff8e03",
|
"id": "4178b576",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"간단한 collate function"
|
"간단한 collate function"
|
||||||
@ -123,7 +123,7 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "89eb64d8",
|
"id": "a5ff0049",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
|
732
Training.ipynb
732
Training.ipynb
File diff suppressed because one or more lines are too long
@ -7,13 +7,15 @@ from ndata import readNsmcRawData, NsmcRawData
|
|||||||
|
|
||||||
def readNsmcDataAll():
|
def readNsmcDataAll():
|
||||||
"""
|
"""
|
||||||
Returns: train, test
|
Returns: train, dev, test
|
||||||
"""
|
"""
|
||||||
print("read train set", file=sys.stderr)
|
print("read train set", file=sys.stderr)
|
||||||
train = readNsmcRawData("nsmc/nsmc-master/ratings_train.txt",use_tqdm=True,total=150_000)
|
train = readNsmcRawData("nsmc/nsmc-master/ratings_train.txt",use_tqdm=True,total=150_000)
|
||||||
print("read test set", file=sys.stderr)
|
print("read test set", file=sys.stderr)
|
||||||
test = readNsmcRawData("nsmc/nsmc-master/ratings_test.txt",use_tqdm=True,total=50_000)
|
testBig = readNsmcRawData("nsmc/nsmc-master/ratings_test.txt",use_tqdm=True,total=50_000)
|
||||||
return NsmcDataset(train),NsmcDataset(test)
|
test = testBig[:30_000]
|
||||||
|
dev = testBig[30_000:]
|
||||||
|
return NsmcDataset(train),NsmcDataset(dev),NsmcDataset(test)
|
||||||
|
|
||||||
class NsmcDataset(Dataset):
|
class NsmcDataset(Dataset):
|
||||||
def __init__(self, data: List[NsmcRawData]):
|
def __init__(self, data: List[NsmcRawData]):
|
||||||
|
Loading…
Reference in New Issue
Block a user