salad.datasets.digits package

Digit datasets used in domain adaptation.

Submodules

salad.datasets.digits.base module

salad.datasets.digits.mnist module

class salad.datasets.digits.mnist.MNIST(root, split='train', transform=None, label_transform=None, download=True)

Bases: torchvision.datasets.mnist.MNIST

MNIST Dataset

images
labels

salad.datasets.digits.openset module

class salad.datasets.digits.openset.OpenSetDataset(dataset, known, unknown, labels=None)

Bases: object

Dataset wrapper for open-set classification.

Works with any classification dataset whose __getitem__ method returns a tuple (x, y). Given two sets of labels, one for known and one for unknown classes, it maps all unknown class labels to zero.
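
The description leaves open how known labels are encoded and what happens to samples that belong to neither set. The following is a minimal sketch of the idea, assuming (1) known classes are relabelled to 1..K so that 0 is reserved for "unknown" and (2) samples in neither set are dropped. OpenSetSketch is a hypothetical name, not the salad implementation, and the labels argument is not modelled.

```python
class OpenSetSketch:
    """Hypothetical open-set wrapper; not salad's OpenSetDataset."""

    def __init__(self, dataset, known, unknown):
        self.dataset = dataset
        # Assumption: known classes become 1..K, so label 0 means "unknown".
        self.label_map = {c: i + 1 for i, c in enumerate(known)}
        for c in unknown:
            self.label_map[c] = 0
        # Assumption: samples whose class is in neither set are dropped.
        self.indices = [
            i for i in range(len(dataset)) if int(dataset[i][1]) in self.label_map
        ]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        x, y = self.dataset[self.indices[idx]]
        return x, self.label_map[int(y)]


# Usage on a toy list of (x, y) pairs standing in for a real dataset.
toy = [("a", 0), ("b", 1), ("c", 2), ("d", 3)]
ds = OpenSetSketch(toy, known=[0, 1], unknown=[2])
print(len(ds), [ds[i] for i in range(len(ds))])  # 3 [('a', 1), ('b', 2), ('c', 0)]
```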

salad.datasets.digits.openset.get_data(train=True, batch_size=128)

salad.datasets.digits.synth module

class salad.datasets.digits.synth.Synth(root, split='train', transform=None, label_transform=None, download=True)

Bases: salad.datasets.digits.base._BaseDataset

Synthetic digits dataset.

extract_images_labels(filename)
image_shape = [16, 16, 1]
num_labels = 10
test_file = 'synth_test_32x32.mat?raw=true'
training_file = 'synth_train_32x32.mat?raw=true'
urls = {'https://github.com/domainadaptation/datasets/blob/master/synth/synth_test_32x32.mat?raw=true', 'https://github.com/domainadaptation/datasets/blob/master/synth/synth_train_32x32.mat?raw=true'}
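
For Synth, training_file and test_file are exactly the last path segment of the matching download URL, including the ?raw=true query string that makes GitHub serve the raw file instead of its HTML blob page. The sketch below reproduces that mapping; local_filename is a hypothetical helper, not part of salad, and it assumes file names are derived this way rather than read from the class attributes.

```python
# URLs copied from the Synth.urls attribute documented above.
SYNTH_URLS = {
    "https://github.com/domainadaptation/datasets/blob/master/synth/synth_test_32x32.mat?raw=true",
    "https://github.com/domainadaptation/datasets/blob/master/synth/synth_train_32x32.mat?raw=true",
}


def local_filename(url: str) -> str:
    """Hypothetical helper: return everything after the last '/', query included."""
    return url.rsplit("/", 1)[-1]


# urls is a set with no guaranteed order, so sort the derived names.
names = sorted(local_filename(u) for u in SYNTH_URLS)
print(names)  # ['synth_test_32x32.mat?raw=true', 'synth_train_32x32.mat?raw=true']
```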

class salad.datasets.digits.synth.SynthSmall(root, split='train', transform=None, label_transform=None, download=True)

Bases: salad.datasets.digits.base._BaseDataset

Synthetic digits dataset, small variant.

extract_images_labels(filename)
image_shape = [16, 16, 1]
num_labels = 10
test_file = 'synth_test_32x32.mat_small?raw=true'
training_file = 'synth_train_32x32_small.mat?raw=true'
urls = {'https://github.com/domainadaptation/datasets/blob/master/synth/synth_test_32x32_small.mat?raw=true', 'https://github.com/domainadaptation/datasets/blob/master/synth/synth_train_32x32_small.mat?raw=true'}

salad.datasets.digits.usps module

class salad.datasets.digits.usps.USPS(root, split='train', transform=None, label_transform=None, download=True)

Bases: salad.datasets.digits.base._BaseDataset

[USPS](http://statweb.stanford.edu/~tibs/ElemStatLearn/data.html) Dataset.

Parameters:
  • root (string) – Root directory of dataset where processed/training.pt and processed/test.pt exist.
  • split (string, optional) – If 'train' (the default), creates the dataset from the training file, otherwise from the test file.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g., transforms.RandomCrop.
  • label_transform (callable, optional) – A function/transform that takes in the label and transforms it.

Download the USPS dataset from [1], or use the explicit links [2] for training and [3] for testing. The code for loading the dataset is partly adapted from [4] (Apache License 2.0).

References

[1] http://statweb.stanford.edu/~tibs/ElemStatLearn/data.html
[2] Training dataset: http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/zip.train.gz
[3] Test dataset: http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/zip.test.gz
[4] https://github.com/haeusser/learning_by_association/blob/master/semisup/tools/usps.py

extract_images_labels(filename)
image_shape = [16, 16, 1]
num_labels = 10
test_file = 'zip.train.gz'
training_file = 'zip.train.gz'
urls = ['http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/zip.train.gz', 'http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/zip.test.gz']
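
The ElemStatLearn files referenced in [2] and [3] store one sample per line: the digit label followed by the 256 grayscale values of a 16x16 image, which matches image_shape = [16, 16, 1]. Below is a minimal sketch of a parser in the spirit of extract_images_labels, assuming that layout. read_usps is hypothetical and is not salad's implementation (which is partly adapted from [4]). The usage part writes a tiny toy file rather than downloading real data.

```python
import gzip
import os
import tempfile

import numpy as np


def read_usps(path):
    """Hypothetical parser for a USPS zip.*.gz file, returning (images, labels).

    Assumes one sample per line: a digit label followed by 256
    whitespace-separated grayscale values (the ElemStatLearn layout).
    """
    with gzip.open(path, "rt") as f:
        data = np.loadtxt(f, ndmin=2)  # shape (N, 257), even for a single line
    labels = data[:, 0].astype(np.int64)  # labels may be written as floats
    images = data[:, 1:].reshape(-1, 16, 16, 1).astype(np.float32)
    return images, labels


# Usage on a tiny synthetic file with two samples (toy values, not real USPS data).
tmp_dir = tempfile.mkdtemp()
toy_path = os.path.join(tmp_dir, "zip.train.gz")
with gzip.open(toy_path, "wt") as f:
    f.write("3.0000 " + " ".join(["-1"] * 256) + "\n")
    f.write("7.0000 " + " ".join(["1"] * 256) + "\n")

images, labels = read_usps(toy_path)
print(images.shape, labels.tolist())  # (2, 16, 16, 1) [3, 7]
```

Reading through np.loadtxt keeps the sketch short. For the full training file (a few thousand lines), a faster parser may be preferable, but the output layout would be the same.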