salad.layers package
Submodules
salad.layers.association module
-
class salad.layers.association.Accuracy
Bases: torch.nn.modules.module.Module
-
forward(input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
-
class salad.layers.association.AssociationMatrix(verbose=False)
Bases: torch.nn.modules.module.Module
-
forward(xs, xt)
xs: (Ns, K, …) source features; xt: (Nt, K, …) target features
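Associative domain adaptation builds this matrix from dot-product similarities between source and target embeddings and turns it into transition probabilities. A minimal sketch, assuming xs and xt have already been flattened to shapes (Ns, K) and (Nt, K); the helper name `association_probs` is hypothetical and the code is not taken from salad:

```python
import torch
import torch.nn.functional as F

def association_probs(xs: torch.Tensor, xt: torch.Tensor):
    """Hypothetical helper: transition probabilities between source and target.

    xs: (Ns, K) source embeddings, xt: (Nt, K) target embeddings.
    """
    M = xs @ xt.t()                  # (Ns, Nt) dot-product similarities
    P_st = F.softmax(M, dim=1)       # source -> target transition probabilities
    P_ts = F.softmax(M.t(), dim=1)   # target -> source transition probabilities
    P_sts = P_st @ P_ts              # (Ns, Ns) round-trip probabilities
    return P_st, P_ts, P_sts

xs, xt = torch.randn(6, 16), torch.randn(10, 16)
P_st, P_ts, P_sts = association_probs(xs, xt)
print(P_sts.sum(dim=1))  # each row sums to 1 (up to float error)
```

Because every row of P_st and P_ts is a probability distribution, every row of the round-trip matrix P_sts is one as well.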
-
-
class salad.layers.association.AssociativeLoss(walker_weight=1.0, visit_weight=1.0)
Bases: torch.nn.modules.module.Module
Association Loss for Domain Adaptation
Reference: Associative Domain Adaptation, Haeusser et al. (2017)
-
forward(xs, xt, y)
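In the referenced paper the associative loss combines a walker term and a visit term. A sketch of that objective, where A holds source embeddings, B target embeddings, y the source labels and H the cross-entropy; reading walker_weight and visit_weight as the weights w_walker and w_visit is an assumption about how salad maps its constructor arguments:

$$M = A B^\top, \qquad P^{ab}_{ij} = \frac{\exp M_{ij}}{\sum_{j'} \exp M_{ij'}}, \qquad P^{ba}_{ji} = \frac{\exp M_{ij}}{\sum_{i'} \exp M_{i'j}}, \qquad P^{aba} = P^{ab} P^{ba}$$

$$\mathcal{L} = w_{\text{walker}}\, H\!\left(T, P^{aba}\right) + w_{\text{visit}}\, H\!\left(V, P^{\text{visit}}\right), \qquad T_{ij} = \frac{[y_i = y_j]}{\sum_k [y_i = y_k]}, \qquad P^{\text{visit}}_j = \frac{1}{N_s} \sum_i P^{ab}_{ij}, \qquad V_j = \frac{1}{N_t}$$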
-
-
class salad.layers.association.AugmentationLoss(aug_loss_func=MSELoss(), use_rampup=True)
Bases: torch.nn.modules.module.Module
Augmentation Loss from https://github.com/Britefury/self-ensemble-visual-domain-adapt
-
forward()
-
-
class salad.layers.association.ClassBalanceLoss
Bases: torch.nn.modules.module.Module
Class Balance Loss from https://github.com/Britefury/self-ensemble-visual-domain-adapt
-
forward(tea_out)
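The self-ensembling work this class points to discourages predictions from collapsing onto a few classes by comparing the batch-averaged class probabilities with a uniform target 1/C through binary cross-entropy. A minimal sketch, assuming tea_out holds softmax probabilities of shape (N, C); the function name `class_balance_loss` is hypothetical and salad's exact scaling may differ:

```python
import torch
import torch.nn.functional as F

def class_balance_loss(tea_out: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: tea_out is (N, C) softmax probabilities."""
    n_classes = tea_out.size(1)
    avg_prob = tea_out.mean(dim=0)                        # (C,) batch-mean prediction
    uniform = torch.full_like(avg_prob, 1.0 / n_classes)  # balanced target
    # Element-wise BCE, minimised when avg_prob equals the uniform target.
    return F.binary_cross_entropy(avg_prob, uniform)

balanced = torch.full((4, 5), 0.2)
collapsed = F.one_hot(torch.zeros(4, dtype=torch.long), 5).float()
print(class_balance_loss(balanced) < class_balance_loss(collapsed))  # tensor(True)
```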
-
-
class salad.layers.association.OTLoss
Bases: torch.nn.modules.module.Module
-
forward(xs, ys, xt, yt)
-
-
class salad.layers.association.VisitLoss
Bases: torch.nn.modules.module.Module
-
forward(Pt)
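The visit loss from associative domain adaptation asks the source-to-target walk to visit every target sample about equally often. A minimal sketch, assuming Pt is the (Ns, Nt) source-to-target transition matrix; the function name `visit_loss` and the eps guard are hypothetical:

```python
import torch

def visit_loss(P_st: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical sketch. P_st: (Ns, Nt) source-to-target transition probabilities."""
    p_visit = P_st.mean(dim=0)  # (Nt,) probability of visiting each target sample
    # Cross-entropy between a uniform distribution over targets and p_visit.
    return -torch.log(p_visit + eps).mean()

uniform = torch.full((4, 8), 1 / 8)
peaked = torch.zeros(4, 8)
peaked[:, 0] = 1.0  # every walk lands on the same target sample
print(visit_loss(uniform) < visit_loss(peaked))  # tensor(True)
```

For a perfectly uniform walk the loss equals log(Nt), its minimum.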
-
-
class salad.layers.association.WalkerLoss
Bases: torch.nn.modules.module.Module
-
forward(Psts, y)
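The walker loss rewards round trips source → target → source that end at a sample with the same label as the start, spread uniformly over those samples. A minimal sketch, assuming Psts is the (Ns, Ns) round-trip matrix and y holds integer source labels; `walker_loss` and eps are hypothetical names:

```python
import torch

def walker_loss(P_sts: torch.Tensor, y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical sketch. P_sts: (Ns, Ns) round-trip probabilities, y: (Ns,) labels."""
    same_class = (y[:, None] == y[None, :]).float()
    # Target: uniform over the source samples sharing the starting sample's label.
    target = same_class / same_class.sum(dim=1, keepdim=True)
    return -(target * torch.log(P_sts + eps)).sum(dim=1).mean()

y = torch.tensor([0, 0, 1, 1])
same = (y[:, None] == y[None, :]).float()
ideal = same / same.sum(dim=1, keepdim=True)  # round trips never change class
mixed = torch.full((4, 4), 0.25)              # round trips ignore class
print(walker_loss(ideal, y) < walker_loss(mixed, y))  # tensor(True)
```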
-
-
class salad.layers.association.WassersteinLoss
Bases: torch.nn.modules.module.Module
-
forward(input)
-
salad.layers.base module
-
class salad.layers.base.AccuracyScore
Bases: torch.nn.modules.module.Module
-
forward(y, t)
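A plausible reading of AccuracyScore(y, t) is the fraction of samples whose highest-scoring class matches the label. A minimal sketch, assuming y holds (N, C) class scores and t holds (N,) integer labels; the function name is hypothetical:

```python
import torch

def accuracy_score(y: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: y is (N, C) class scores, t is (N,) integer labels."""
    pred = y.argmax(dim=1)            # predicted class per sample
    return (pred == t).float().mean()

scores = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [0.1, 0.4]])
labels = torch.tensor([0, 1, 1, 1])
print(accuracy_score(scores, labels))  # tensor(0.7500)
```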
-
-
class salad.layers.base.KLDivWithLogits
Bases: torch.nn.modules.module.Module
-
forward(x, y)
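Judging by the name, KLDivWithLogits compares two sets of logits after normalising them with a softmax. A minimal sketch, assuming y are the target logits, x the logits being fitted, and a per-batch mean reduction; the direction of the divergence and the function name are assumptions:

```python
import torch
import torch.nn.functional as F

def kl_div_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: KL(softmax(y) || softmax(x)), averaged over the batch."""
    log_q = F.log_softmax(x, dim=1)  # fitted distribution, as log-probabilities
    p = F.softmax(y, dim=1)          # target distribution, as probabilities
    return F.kl_div(log_q, p, reduction="batchmean")

a = torch.randn(3, 4)
print(kl_div_with_logits(a, a))  # ~0 when both logit sets agree
```

F.kl_div expects log-probabilities as its first argument and probabilities as its second, which is why the two inputs are normalised differently.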
-
-
class salad.layers.base.MeanAccuracyScore
Bases: torch.nn.modules.module.Module
-
forward(y, t)
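Mean accuracy usually denotes per-class accuracy averaged over classes, which keeps a frequent class from dominating the score on imbalanced data. A minimal sketch under that assumption, with y as (N, C) class scores and t as (N,) integer labels; the function name is hypothetical:

```python
import torch

def mean_accuracy_score(y: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: accuracy per class, averaged over classes present in t."""
    pred = y.argmax(dim=1)
    per_class = [(pred[t == c] == c).float().mean() for c in t.unique()]
    return torch.stack(per_class).mean()

scores = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [0.1, 0.4]])
labels = torch.tensor([0, 1, 1, 1])
# class 0: 1/1 correct, class 1: 2/3 correct
print(mean_accuracy_score(scores, labels))  # tensor(0.8333)
```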
-
-
class salad.layers.base.WeightedCE(confidence_threshold=0.96837722)
Bases: torch.nn.modules.module.Module
Adapted from the Self-Ensembling repository.
-
forward(logits, logits_target)
-
robust_binary_crossentropy(pred, tgt)
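The name matches an epsilon-stabilised binary cross-entropy used in the self-ensembling code that WeightedCE is adapted from. A sketch assuming that form, with eps = 1e-6 (an assumed value) inside both logarithms so that predictions of exactly 0 or 1 stay finite:

```python
import torch

def robust_binary_crossentropy(pred: torch.Tensor, tgt: torch.Tensor,
                               eps: float = 1e-6) -> torch.Tensor:
    """Sketch: element-wise BCE with eps inside both logs (no reduction)."""
    inv_tgt = 1.0 - tgt
    inv_pred = 1.0 - pred + eps
    return -(tgt * torch.log(pred + eps) + inv_tgt * torch.log(inv_pred))

pred = torch.tensor([0.0, 1.0, 0.5])
tgt = torch.tensor([0.0, 1.0, 1.0])
print(robust_binary_crossentropy(pred, tgt))  # finite even where pred is exactly 0 or 1
```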
-
salad.layers.coral module
-
class salad.layers.coral.AffineInvariantDivergence
-
class salad.layers.coral.CoralLoss
Bases: salad.layers.coral.CorrelationDistance
Deep CORAL loss from the paper https://arxiv.org/pdf/1607.01719.pdf
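The Deep CORAL paper defines the loss as the squared Frobenius distance between the source and target feature covariances, scaled by 1/(4 d^2) for d-dimensional features. A minimal sketch of that formula; salad's CoralLoss derives from CorrelationDistance with a configurable distance, so its scaling may differ, and the helper names are hypothetical:

```python
import torch

def covariance(x: torch.Tensor) -> torch.Tensor:
    """(N, d) features -> (d, d) unbiased covariance matrix."""
    xc = x - x.mean(dim=0, keepdim=True)
    return xc.t() @ xc / (x.size(0) - 1)

def coral_loss(xs: torch.Tensor, xt: torch.Tensor) -> torch.Tensor:
    """Squared Frobenius distance between covariances, scaled by 1 / (4 d^2)."""
    d = xs.size(1)
    diff = covariance(xs) - covariance(xt)
    return (diff ** 2).sum() / (4 * d * d)

xs = torch.randn(64, 8)
print(coral_loss(xs, xs))          # tensor(0.) for identical batches
print(coral_loss(xs, 3 * xs) > 0)  # tensor(True): scaling changes the covariance
```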
-
class salad.layers.coral.CorrelationDistance(distance=<function euclid>)
Bases: torch.nn.modules.module.Module
-
forward(xs, xt)
-
-
class salad.layers.coral.JeffreyDivergence
Bases: salad.layers.coral.CorrelationDistance
Log Coral Loss
-
class salad.layers.coral.LogCoralLoss
Bases: torch.nn.modules.module.Module
-
forward(xs, xt)
-
-
class salad.layers.coral.SteinDivergence
Bases: salad.layers.coral.CorrelationDistance
Log Coral Loss
salad.layers.da module
-
class salad.layers.da.AutoAlign2d
Bases: torch.nn.modules.batchnorm.BatchNorm2d
-
forward()
-
-
class salad.layers.da.FeatureAwareNormalization
Bases: torch.nn.modules.module.Module
salad.layers.funcs module
-
salad.layers.funcs.concat(x, z)
Concatenate a 4D tensor with an expanded 2D tensor.
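One plausible reading of concat is the common conditioning trick of tiling a per-sample vector over every spatial position and stacking it onto the feature map as extra channels. A minimal sketch under that assumption, with x of shape (N, C, H, W) and z of shape (N, D); the channel-axis choice is also an assumption:

```python
import torch

def concat(x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: x is (N, C, H, W), z is (N, D); returns (N, C + D, H, W)."""
    n, _, h, w = x.shape
    z_map = z.view(n, -1, 1, 1).expand(-1, -1, h, w)  # repeat z at every spatial position
    return torch.cat([x, z_map], dim=1)               # stack along the channel axis

x = torch.randn(2, 3, 5, 5)
z = torch.randn(2, 4)
out = concat(x, z)
print(out.shape)  # torch.Size([2, 7, 5, 5])
```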
salad.layers.mat module
Metrics and Divergences for Correlation Matrices
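For symmetric positive-definite d × d matrices A and B, the function names below most likely refer to the following standard quantities (log is the matrix logarithm): euclid to the Euclidean distance, logeuclid to the log-Euclidean distance, affineinvariant and presumably riemann to the affine-invariant Riemannian distance, jeffrey to the Jeffrey divergence and stein to the Stein divergence. These are reference definitions only; salad's functions may use different scaling, for example squared rather than unsquared distances:

$$d_{\mathrm{E}}(A,B) = \|A - B\|_F, \qquad d_{\mathrm{LE}}(A,B) = \|\log A - \log B\|_F, \qquad d_{\mathrm{AI}}(A,B) = \left\|\log\!\left(A^{-1/2} B A^{-1/2}\right)\right\|_F$$

$$J(A,B) = \tfrac{1}{2}\operatorname{tr}\!\left(A^{-1}B + B^{-1}A\right) - d, \qquad S(A,B) = \log\det\frac{A+B}{2} - \tfrac{1}{2}\log\det(AB)$$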
-
salad.layers.mat.affineinvariant(A, B)
-
salad.layers.mat.apply(C, func)
-
salad.layers.mat.cov(x, eps=1e-05)
Estimate the covariance matrix
-
salad.layers.mat.euclid(A, B)
-
salad.layers.mat.getdata(N, d, std)
-
salad.layers.mat.jeffrey(A, B)
-
salad.layers.mat.logeuclid(A, B)
-
salad.layers.mat.riemann(A, B)
-
salad.layers.mat.stable_logdet(A)
Compute the logarithm of the determinant of a matrix in a numerically stable way.
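A standard way to get log det(A) stably for a symmetric positive-definite A is to factor A = L Lᵀ with a Cholesky decomposition and sum the logs of the diagonal of L, which never forms the determinant itself. A sketch of that technique; whether salad's stable_logdet uses Cholesky is an assumption:

```python
import torch

def stable_logdet(A: torch.Tensor) -> torch.Tensor:
    """Sketch: log det(A) for a symmetric positive-definite A via Cholesky."""
    L = torch.linalg.cholesky(A)  # A = L @ L.T with L lower-triangular
    # det(A) = prod(diag(L))**2, so log det(A) = 2 * sum(log(diag(L))).
    return 2.0 * torch.log(torch.diagonal(L)).sum()

A = 0.1 * torch.eye(100)
print(torch.det(A))       # 0.1**100 is far below float32's range, so this underflows to 0
print(stable_logdet(A))   # about 100 * log(0.1) = -230.26
```

torch.linalg.slogdet is a built-in alternative that also handles matrices that are not positive-definite.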
-
salad.layers.mat.stein(A, B)
salad.layers.vat module
-
class salad.layers.vat.ConditionalEntropy
Bases: torch.nn.modules.module.Module
Estimates the conditional entropy of the input:
$$-\frac{1}{n} \sum_i \sum_c p(y_i = c \mid x_i) \log p(y_i = c \mid x_i)$$
By default, samples are assumed to lie along the first dimension and class probabilities along the second.
-
forward(input)
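A minimal sketch of the conditional-entropy formula in the docstring, assuming the input holds unnormalised logits (samples along dimension 0, classes along dimension 1) that are turned into probabilities with a softmax; if the input already holds probabilities, p * log p can be used directly. The function name is hypothetical:

```python
import torch
import torch.nn.functional as F

def conditional_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: samples along dim 0, class logits along dim 1."""
    log_p = F.log_softmax(logits, dim=1)
    # -(1/n) * sum_i sum_c p(y_i = c | x_i) * log p(y_i = c | x_i)
    return -(log_p.exp() * log_p).sum(dim=1).mean()

uniform = torch.zeros(4, 10)   # identical logits -> uniform prediction
confident = torch.zeros(4, 10)
confident[:, 0] = 50.0         # almost one-hot prediction
print(conditional_entropy(uniform))    # about log(10) = 2.3026
print(conditional_entropy(confident))  # close to 0
```

Minimising this term pushes the classifier towards confident predictions on unlabeled data.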
-
class salad.layers.vat.VATLoss(model, radius=1)
Bases: torch.nn.modules.module.Module
Virtual Adversarial Training Loss function
Reference: TODO
-
forward(x, p)
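Virtual adversarial training penalises how much the prediction changes under the small input perturbation that changes it most, with that direction estimated by one step of power iteration. A minimal sketch, assuming p holds the model's logits on the clean input x and radius is the perturbation length; xi, l2_normalize and vat_loss are hypothetical names, with l2_normalize standing in for what normalize_perturbation presumably does:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def l2_normalize(d: torch.Tensor) -> torch.Tensor:
    """Scale each sample's perturbation to unit L2 norm."""
    norms = d.flatten(1).norm(dim=1).clamp_min(1e-12)
    return d / norms.view(-1, *([1] * (d.dim() - 1)))

def vat_loss(model: nn.Module, x: torch.Tensor, p: torch.Tensor,
             radius: float = 1.0, xi: float = 0.1) -> torch.Tensor:
    """Sketch: p holds the model's logits on the clean input x."""
    target = F.softmax(p.detach(), dim=1)  # clean prediction, treated as fixed
    # One power-iteration step: probe with a small random direction.
    d = (xi * l2_normalize(torch.randn_like(x))).requires_grad_()
    dist = F.kl_div(F.log_softmax(model(x + d), dim=1), target, reduction="batchmean")
    grad, = torch.autograd.grad(dist, d)
    # Adversarial perturbation of length `radius` along the most sensitive direction.
    r_adv = radius * l2_normalize(grad)
    return F.kl_div(F.log_softmax(model(x + r_adv), dim=1), target, reduction="batchmean")

model = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 3))
x = torch.randn(8, 4)
loss = vat_loss(model, x, model(x))
loss.backward()  # gradients flow into the model parameters
```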
-
-
salad.layers.vat.normalize_perturbation(d)