salad.layers package

Submodules

salad.layers.association module

class salad.layers.association.Accuracy

Bases: torch.nn.modules.module.Module

forward(input)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.association.AssociationMatrix(verbose=False)

Bases: torch.nn.modules.module.Module

forward(xs, xt)

xs: (Ns, K, …) xt: (Nt, K, …)
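
For reference, the association matrices of Haeusser et al. (2017), which this module follows, are built from dot-product similarities between source embeddings x^s and target embeddings x^t (the notation below is the paper's and is only a sketch of what this class computes):

$$
M_{ij} = \langle x^s_i, x^t_j \rangle, \quad
P^{st}_{ij} = \frac{\exp(M_{ij})}{\sum_{j'} \exp(M_{ij'})}, \quad
P^{ts}_{ji} = \frac{\exp(M_{ij})}{\sum_{i'} \exp(M_{i'j})}, \quad
P^{sts} = P^{st} P^{ts}
$$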

class salad.layers.association.AssociativeLoss(walker_weight=1.0, visit_weight=1.0)

Bases: torch.nn.modules.module.Module

Association Loss for Domain Adaptation

Reference: Associative Domain Adaptation, Haeusser et al. (2017)

forward(xs, xt, y)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
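
A minimal usage sketch (tensor names, shapes, and weights below are illustrative assumptions, not part of the API):

    import torch
    from salad.layers.association import AssociativeLoss

    loss_fn = AssociativeLoss(walker_weight=1.0, visit_weight=0.1)

    # source/target embeddings of equal feature dimension, plus source labels
    feats_source = torch.randn(32, 128, requires_grad=True)   # (Ns, K)
    feats_target = torch.randn(48, 128, requires_grad=True)   # (Nt, K)
    labels_source = torch.randint(0, 10, (32,))

    loss = loss_fn(feats_source, feats_target, labels_source)
    loss.backward()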

class salad.layers.association.AugmentationLoss(aug_loss_func=MSELoss(), use_rampup=True)

Bases: torch.nn.modules.module.Module

Augmentation Loss from https://github.com/Britefury/self-ensemble-visual-domain-adapt

forward()

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.association.ClassBalanceLoss

Bases: torch.nn.modules.module.Module

Class Balance Loss from https://github.com/Britefury/self-ensemble-visual-domain-adapt

forward(tea_out)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
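
As a hedged sketch of the idea (based on the referenced self-ensembling code, not necessarily this exact implementation), a class-balance term of this kind pushes the batch-averaged teacher prediction towards a uniform class distribution:

    import torch

    def class_balance_penalty(tea_out):
        # tea_out: (N, C) class probabilities predicted by the teacher
        avg_cls_prob = tea_out.mean(dim=0)            # average probability per class
        target = torch.full_like(avg_cls_prob, 1.0 / tea_out.size(1))
        eps = 1e-6
        # binary cross-entropy between the average prediction and the uniform target
        bce = -(target * torch.log(avg_cls_prob + eps)
                + (1 - target) * torch.log(1 - avg_cls_prob + eps))
        return bce.mean()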

class salad.layers.association.OTLoss

Bases: torch.nn.modules.module.Module

forward(xs, ys, xt, yt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.association.VisitLoss

Bases: torch.nn.modules.module.Module

forward(Pt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.association.WalkerLoss

Bases: torch.nn.modules.module.Module

forward(Psts, y)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
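
For reference, WalkerLoss and VisitLoss correspond to the walker and visit terms of Haeusser et al. (2017); the standard definitions are cross-entropies H(·,·) against uniform targets (constants and details of this implementation may differ):

$$
\mathcal{L}_{\text{walker}} = H(T, P^{sts}), \qquad
T_{ij} = \begin{cases} 1/|\{k : y_k = y_i\}| & \text{if } y_i = y_j \\ 0 & \text{otherwise} \end{cases}
$$

$$
\mathcal{L}_{\text{visit}} = H(V, P^{\text{visit}}), \qquad
P^{\text{visit}}_j = \frac{1}{N_s} \sum_i P^{st}_{ij}, \qquad
V_j = \frac{1}{N_t}
$$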

class salad.layers.association.WassersteinLoss

Bases: torch.nn.modules.module.Module

forward(input)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

salad.layers.base module

class salad.layers.base.AccuracyScore

Bases: torch.nn.modules.module.Module

forward(y, t)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.base.KLDivWithLogits

Bases: torch.nn.modules.module.Module

forward(x, y)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.base.MeanAccuracyScore

Bases: torch.nn.modules.module.Module

forward(y, t)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.base.WeightedCE(confidence_threshold=0.96837722)

Bases: torch.nn.modules.module.Module

Adapted from Self-Ensembling repository

forward(logits, logits_target)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
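
A hedged sketch of the confidence-thresholded cross entropy this class is adapted from (names and details below are illustrative, not the exact implementation):

    import torch
    import torch.nn.functional as F

    def confidence_weighted_ce(logits, logits_target, confidence_threshold=0.96837722):
        # teacher/target distribution and its per-sample confidence
        target_prob = F.softmax(logits_target, dim=1)
        confidence, _ = target_prob.max(dim=1)
        mask = (confidence > confidence_threshold).float()
        # cross entropy of the student prediction against the teacher distribution
        ce = -(target_prob * F.log_softmax(logits, dim=1)).sum(dim=1)
        # only confident samples contribute to the loss
        return (mask * ce).mean()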

robust_binary_crossentropy(pred, tgt)

salad.layers.coral module

class salad.layers.coral.AffineInvariantDivergence

Bases: salad.layers.coral.CorrelationDistance

class salad.layers.coral.CoralLoss

Bases: salad.layers.coral.CorrelationDistance

Deep CORAL loss from paper: https://arxiv.org/pdf/1607.01719.pdf
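
For reference, the Deep CORAL objective from the cited paper matches second-order statistics of source and target features, where C_S and C_T are the feature covariance matrices and d is the feature dimension:

$$
\mathcal{L}_{\text{CORAL}} = \frac{1}{4 d^2} \lVert C_S - C_T \rVert_F^2
$$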

class salad.layers.coral.CorrelationDistance(distance=<function euclid>)

Bases: torch.nn.modules.module.Module

forward(xs, xt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.coral.JeffreyDivergence

Bases: salad.layers.coral.CorrelationDistance

Jeffrey divergence between correlation matrices

class salad.layers.coral.LogCoralLoss

Bases: torch.nn.modules.module.Module

forward(xs, xt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.coral.SteinDivergence

Bases: salad.layers.coral.CorrelationDistance

Stein divergence between correlation matrices

salad.layers.da module

class salad.layers.da.AutoAlign2d

Bases: torch.nn.modules.batchnorm.BatchNorm2d

forward()

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.da.FeatureAwareNormalization

Bases: torch.nn.modules.module.Module

salad.layers.funcs module

salad.layers.funcs.concat(x, z)

Concatenate a 4D tensor with an expanded 2D tensor
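
A hedged sketch of what such a concatenation typically looks like (shapes and broadcasting below are assumptions for illustration, not necessarily this function's exact behaviour):

    import torch

    x = torch.randn(8, 64, 16, 16)   # 4D feature map (N, C, H, W)
    z = torch.randn(8, 10)           # 2D conditioning vector (N, Z)

    # expand z to (N, Z, H, W) so it can be concatenated along the channel axis
    z_exp = z[:, :, None, None].expand(-1, -1, x.size(2), x.size(3))
    out = torch.cat([x, z_exp], dim=1)   # (N, C + Z, H, W)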

salad.layers.mat module

Metrics and Divergences for Correlation Matrices

salad.layers.mat.affineinvariant(A, B)
salad.layers.mat.apply(C, func)
salad.layers.mat.cov(x, eps=1e-05)

Estimate the covariance matrix

salad.layers.mat.euclid(A, B)
salad.layers.mat.getdata(N, d, std)
salad.layers.mat.jeffrey(A, B)
salad.layers.mat.logeuclid(A, B)
salad.layers.mat.riemann(A, B)
salad.layers.mat.stable_logdet(A)

Compute the logarithm of the determinant of a matrix in a numerically stable way

salad.layers.mat.stein(A, B)
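
For orientation, these helpers correspond to standard distances and divergences between symmetric positive definite matrices A and B of size d (textbook definitions; scaling and implementation details may differ):

$$
d_{\text{euclid}}(A,B) = \lVert A - B \rVert_F, \qquad
d_{\text{logeuclid}}(A,B) = \lVert \log A - \log B \rVert_F, \qquad
d_{\text{riemann}}(A,B) = \lVert \log(A^{-1/2} B A^{-1/2}) \rVert_F
$$

$$
d_{\text{stein}}(A,B) = \log\det\Big(\frac{A+B}{2}\Big) - \frac{1}{2}\log\det(AB), \qquad
d_{\text{jeffrey}}(A,B) = \frac{1}{2}\,\mathrm{tr}\big(A^{-1}B + B^{-1}A\big) - d
$$

stable_logdet typically relies on a Cholesky factorization, since log det A = 2 * sum_i log L_ii for A = L L^T.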

salad.layers.vat module

class salad.layers.vat.ConditionalEntropy

Bases: torch.nn.modules.module.Module

Estimates the conditional entropy of the input

$$
-\frac{1}{n} \sum_i \sum_c p(y_i = c \mid x_i) \log p(y_i = c \mid x_i)
$$

By default, samples are assumed to lie along the first dimension and class probabilities along the second.

forward(input)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
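
A minimal sketch of the estimated quantity, computed from raw logits for numerical stability (the standard estimator, not necessarily this class's exact code path):

    import torch
    import torch.nn.functional as F

    def conditional_entropy(logits):
        # logits: (N, C); samples along dim 0, classes along dim 1
        p = F.softmax(logits, dim=1)
        log_p = F.log_softmax(logits, dim=1)
        # H = -1/N * sum_i sum_c p_ic * log p_ic
        return -(p * log_p).sum(dim=1).mean()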

class salad.layers.vat.VATLoss(model, radius=1)

Bases: torch.nn.modules.module.Module

Virtual Adversarial Training Loss function

Reference: TODO

forward(x, p)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
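
For reference, the virtual adversarial training loss of Miyato et al. perturbs the input along the most loss-sensitive direction inside a ball of the given radius (generic formulation; this implementation may differ in detail):

$$
r_{\text{adv}} = \arg\max_{\lVert r \rVert_2 \le \epsilon} \mathrm{KL}\big(p(\cdot \mid x) \,\Vert\, p(\cdot \mid x + r)\big), \qquad
\mathcal{L}_{\text{VAT}}(x) = \mathrm{KL}\big(p(\cdot \mid x) \,\Vert\, p(\cdot \mid x + r_{\text{adv}})\big)
$$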

salad.layers.vat.normalize_perturbation(d)
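
A hedged sketch of per-sample perturbation normalization of this kind (the exact behaviour is an assumption; shown for a batch of perturbation directions):

    import torch

    def normalize_perturbation_sketch(d):
        # rescale each sample's perturbation to unit L2 norm
        flat = d.reshape(d.size(0), -1)
        norm = flat.norm(p=2, dim=1, keepdim=True) + 1e-12
        return (flat / norm).view_as(d)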