Source code for ls_mlkit.dataset.minist_cifar

import torchvision
import torchvision.transforms as transforms


[docs] def get_transforms(dataset: str, size=None): transform_train = None transform_test = None if dataset == "fashionmnist" or dataset == "mnist": transform_list = [] if size is not None: transform_list.append(transforms.Resize(size)) transform_list.append(transforms.ToTensor()) transform_train = transforms.Compose(transform_list) transform_test = transforms.Compose(transform_list) if dataset == "cifar10": transform_train = transforms.Compose( [ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ] ) transform_test = transforms.Compose( [ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ] ) if dataset == "cifar100": transform_train = transforms.Compose( [ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), ] ) transform_test = transforms.Compose( [ transforms.ToTensor(), transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), ] ) assert transform_test is not None and transform_train is not None, "Error, no dataset %s" % dataset return transform_train, transform_test
[docs] def get_dataset(dataset, root="./data", size=None): transform_train, transform_test = get_transforms(dataset, size) trainset, testset = None, None if dataset == "mnist": trainset = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=transform_train) testset = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=transform_test) if dataset == "fashionmnist": trainset = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform_train) testset = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform_test) if dataset == "cifar10": trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train) testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test) if dataset == "cifar100": trainset = torchvision.datasets.CIFAR100(root=root, train=True, download=True, transform=transform_train) testset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=transform_test) return trainset, testset, testset
[docs] def get_minist_dataset(root="data"): trainset, testset, testset = get_dataset("mnist", root) return trainset, testset, testset
[docs] def get_fashionmnist_dataset(root="data"): trainset, testset, testset = get_dataset("fashionmnist", root) return trainset, testset, testset
[docs] def get_cifar10_dataset(root="data"): trainset, testset, testset = get_dataset("cifar10", root) return trainset, testset, testset
[docs] def get_cifar100_dataset(root="data"): trainset, testset, testset = get_dataset("cifar100", root) return trainset, testset, testset