nip.utils.maths.logit_or_dual
- nip.utils.maths.logit_or_dual(a: Float[Tensor, '... logits'], b: Float[Tensor, '... logits']) → Float[Tensor, '... logits']
Compute the logit OR operation for two input tensors with the log-sum-exp trick.
The logit OR operation is defined as:
\[\max(a, b) + \log(1 + \exp(\min(a, b) - \max(a, b)))\]
where \(\max(a, b)\) is the element-wise maximum of the inputs, and \(\min(a, b)\) is the element-wise minimum of the inputs.
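For reference, a short derivation of why this expression is the log-sum-exp trick. The shorthand \(m = \max(a, b)\) and \(n = \min(a, b)\) is introduced here for illustration and does not appear in the library. Factoring \(e^{m}\) out of the sum of exponentials gives:
\[\log(e^{a} + e^{b}) = \log\left(e^{m}\left(1 + e^{n - m}\right)\right) = m + \log\left(1 + e^{n - m}\right)\]
Since \(n - m \le 0\), the remaining exponential lies in \((0, 1]\) and cannot overflow, whereas evaluating \(e^{a} + e^{b}\) directly can overflow for large logits.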
- Parameters:
a (Float[Tensor, "... logits"]) – The first input tensor.
b (Float[Tensor, "... logits"]) – The second input tensor.
- Returns:
torch.Tensor – The result of the logit OR operation applied element-wise to the input tensors.
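The formula above can be illustrated with a minimal scalar sketch in plain Python. This is not the library's implementation: the names `logit_or_dual_scalar` and `naive_or` are hypothetical, and the sketch works on floats rather than tensors. It only shows why the max/min form stays finite where the direct form overflows.

```python
import math


def logit_or_dual_scalar(a: float, b: float) -> float:
    """Scalar sketch of max(a, b) + log(1 + exp(min(a, b) - max(a, b)))."""
    hi, lo = max(a, b), min(a, b)
    # lo - hi <= 0, so exp() stays in (0, 1] and cannot overflow.
    return hi + math.log1p(math.exp(lo - hi))


def naive_or(a: float, b: float) -> float:
    """Direct evaluation of log(exp(a) + exp(b)), with no stabilisation."""
    return math.log(math.exp(a) + math.exp(b))


# For moderate inputs the two forms agree.
small_stable = logit_or_dual_scalar(0.5, -1.0)
small_naive = naive_or(0.5, -1.0)

# For large inputs only the stabilised form is usable:
# math.exp(1000.0) raises OverflowError inside naive_or.
large_stable = logit_or_dual_scalar(1000.0, 999.0)
```

For tensor inputs, the same expression could presumably be written element-wise with `torch.maximum`, `torch.minimum`, and `torch.log1p`. Whether the library does so is not stated in this page.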