Source code for ls_mlkit.dataset.iris
import numpy as np
import pandas as pd
import torch
from torch.utils.data import random_split
from torch.utils.data.dataset import Dataset
[docs]
class IrisDataset(Dataset):
    """Iris dataset read from a CSV file, exposed as ``(features, label)`` pairs.

    The CSV's first line is discarded as a header row. Column 0 is skipped
    (presumably an ``Id`` column — verify against your CSV). The column at
    ``label_col_index`` becomes the label, and every remaining column is a
    feature.

    Args:
        path: Path to the CSV file.
        label_col_index: Position of the label column. Negative values count
            from the end, as in normal Python indexing.
    """

    def __init__(self, path="./data/Iris.csv", label_col_index=5):
        super().__init__()
        # Read every line as data, then drop row 0 (the header line) so the
        # columns keep their positional integer labels 0..n-1.
        df = pd.read_csv(path, header=None, index_col=None)
        df = df.drop(index=0).reset_index(drop=True)
        # NOTE(review): appears to encode the label column as integers in
        # place (its return value is unused) — confirm against its definition.
        transform_label_to_integer(df, label_col_index)
        data = transform_dataframe_to_tensor(df)

        num_cols = data.shape[1]
        # Normalise so negative indices work and the label column can be
        # excluded from the features by position.
        label_col = label_col_index % num_cols
        # Features: every column except the leading column 0 and the label.
        # (Previously this was hard-coded as x[:, 1:-1] / x[:, -1], which
        # ignored label_col_index.)
        feature_cols = [c for c in range(1, num_cols) if c != label_col]
        self.x = data[:, feature_cols]
        self.y = data[:, label_col].to(torch.long)

    def __len__(self):
        """Return the number of samples."""
        return len(self.y)

    def __getitem__(self, item):
        """Return the ``(features, label)`` pair at position ``item``."""
        return self.x[item], self.y[item]
[docs]
def get_iris_dataset(test_ratio=0.2, **kwargs):
    """Load ``IrisDataset`` with its default arguments and split it into train/test.

    The split uses a ``torch.Generator`` seeded with 0, so repeated calls
    produce the same partition. The resulting sizes are printed.

    Args:
        test_ratio: Fraction of samples assigned to the test split. Must be
            in ``[0, 1]``.
        **kwargs: Accepted and ignored (not forwarded to the dataset).

    Returns:
        A 3-tuple ``(train_dataset, test_dataset, test_dataset)``. The test
        split is returned twice, so there is no separate validation split.
        NOTE(review): presumably callers unpack this as (train, val, test).

    Raises:
        ValueError: If ``test_ratio`` is not within ``[0, 1]``.
    """
    # A ratio outside [0, 1] would make one split size negative, which
    # random_split does not reject and instead turns into garbage subsets.
    if not 0 <= test_ratio <= 1:
        raise ValueError(f"test_ratio must be in [0, 1], got {test_ratio!r}")

    dataset = IrisDataset()
    num_samples = len(dataset)
    # Round away float noise before truncating. For example,
    # (1 - 0.9) * 10 == 0.9999999999999998 would otherwise truncate to 0.
    train_size = int(round((1 - test_ratio) * num_samples, 6))
    test_size = num_samples - train_size
    train_dataset, test_dataset = random_split(
        dataset, [train_size, test_size], generator=torch.Generator().manual_seed(0)
    )
    print(f"total_size = {len(dataset)}, (train_size, test_size) = {len(train_dataset)}, {len(test_dataset)}")
    return train_dataset, test_dataset, test_dataset
if __name__ == "__main__":
    # Running this module directly does nothing; it only provides definitions.
    ...