nip.utils.maths.logit_or_dual

nip.utils.maths.logit_or_dual(a: Float[Tensor, '... logits'], b: Float[Tensor, '... logits']) → Float[Tensor, '... logits']

Compute the logit OR operation element-wise on two input tensors, using the log-sum-exp trick for numerical stability.

The logit OR operation is defined as:

\[\max(a, b) + \log(1 + \exp(\min(a, b) - \max(a, b)))\]

where \(\max(a, b)\) and \(\min(a, b)\) are the element-wise maximum and minimum of the inputs, and \(\log\) and \(\exp\) are also applied element-wise.
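Factoring out the larger argument shows why the formula is written this way: it is algebraically equal to \(\log(e^{a} + e^{b})\), but the only exponential it evaluates has the non-positive argument \(\min(a, b) - \max(a, b)\), so it cannot overflow.

\[\log\left(e^{a} + e^{b}\right) = \log\left(e^{\max(a, b)}\left(1 + e^{\min(a, b) - \max(a, b)}\right)\right) = \max(a, b) + \log\left(1 + e^{\min(a, b) - \max(a, b)}\right)\]

For example, with \(a = b = 1000\), evaluating \(e^{1000}\) directly overflows in floating point, whereas the right-hand side gives \(1000 + \log 2\).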

Parameters:
  • a (Float[Tensor, "... logits"]) – The first input tensor.

  • b (Float[Tensor, "... logits"]) – The second input tensor.

Returns:

torch.Tensor – The result of the logit OR operation applied element-wise to the input tensors.
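As a rough illustration, the formula above can be written in PyTorch as follows. This is a minimal sketch that assumes the function computes exactly the stated formula; `logit_or_dual_sketch` is a hypothetical name, not the nip implementation. The result is checked against PyTorch's `torch.logaddexp`, which computes \(\log(e^{a} + e^{b})\) element-wise, and contrasted with the naive, overflow-prone expression.

```python
import torch


def logit_or_dual_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: max(a, b) + log(1 + exp(min(a, b) - max(a, b)))."""
    hi = torch.maximum(a, b)  # element-wise maximum (a and b are broadcast together)
    lo = torch.minimum(a, b)  # element-wise minimum
    # lo - hi <= 0, so exp() stays in (0, 1] and cannot overflow;
    # log1p keeps precision when exp(lo - hi) is close to zero.
    return hi + torch.log1p(torch.exp(lo - hi))


a = torch.tensor([0.0, 1000.0, -1000.0, 2.0])
b = torch.tensor([0.0, 1000.0, -1000.0, -3.0])

out = logit_or_dual_sketch(a, b)  # ≈ [0.6931, 1000.6931, -999.3069, 2.0067]

# Naive log(exp(a) + exp(b)): exp(1000) overflows to inf and exp(-1000)
# underflows to 0, giving inf and -inf in the middle two entries.
naive = torch.log(torch.exp(a) + torch.exp(b))

print(out)
print(naive)
print(torch.allclose(out, torch.logaddexp(a, b)))
```

Using `torch.log1p` rather than `torch.log(1 + ...)` avoids losing precision when \(\min(a, b)\) is much smaller than \(\max(a, b)\) and the exponential term is tiny.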