all():只要有一個元素為「假」就返回 False;全部元素為「真」才返回 True(對張量而言,非零元素視為「真」,0 視為「假」)
any():只要有一個元素為「真」就返回 True;全部元素為「假」才返回 False
import torch
# Build a random (2, 4, 6) tensor, zero its first slice along dim 2 in two
# different ways, and compare the results.
a = torch.rand([2, 4, 6])
print(a)

# Method 1: torch.index_fill returns a NEW tensor in which the positions
# listed in `idx` along `dim` are set to `value`; `a` itself is unchanged.
idx = torch.tensor([0])
b = torch.index_fill(a, dim=2, index=idx, value=0.)
print(b)

# NOTE: `x.all() == y.all()` compares two single booleans ("does each tensor
# contain no zeros?"), NOT the tensors element by element. It is printed here
# only to illustrate all()/any(). For b and c below it is True because both
# contain zeros, not because the two tensors are equal.
print(".all for a and b", a.all() == b.all())
print(".any for a and b", a.any() == b.any())

# Method 2: copy `a` and assign zeros to the same positions in place.
c = a.clone().detach()
c[:, :, idx] = 0
print(".all for b and c", c.all() == b.all())
print(".any for b and c", c.any() == b.any())

# Real element-wise check that both methods produce exactly the same tensor
# (same shape and same values), which the all()/any() comparisons cannot show.
assert torch.equal(b, c), "index_fill and slice assignment disagree"
Output (one sample run; values differ on every run because torch.rand is not seeded):
tensor([[[0.9474, 0.3897, 0.4443, 0.0562, 0.4451, 0.6056],
[0.0585, 0.6470, 0.0693, 0.8196, 0.1546, 0.3637],
[0.5138, 0.0938, 0.0193, 0.4533, 0.8031, 0.9386],
[0.8750, 0.8856, 0.6365, 0.8229, 0.9639, 0.4817]],
[[0.0771, 0.9319, 0.0218, 0.8928, 0.6224, 0.1546],
[0.3917, 0.0770, 0.5662, 0.1670, 0.5407, 0.9702],
[0.8385, 0.5863, 0.4124, 0.1201, 0.1032, 0.4559],
[0.0603, 0.2551, 0.8327, 0.8269, 0.1052, 0.1294]]])
tensor([[[0.0000, 0.3897, 0.4443, 0.0562, 0.4451, 0.6056],
[0.0000, 0.6470, 0.0693, 0.8196, 0.1546, 0.3637],
[0.0000, 0.0938, 0.0193, 0.4533, 0.8031, 0.9386],
[0.0000, 0.8856, 0.6365, 0.8229, 0.9639, 0.4817]],
[[0.0000, 0.9319, 0.0218, 0.8928, 0.6224, 0.1546],
[0.0000, 0.0770, 0.5662, 0.1670, 0.5407, 0.9702],
[0.0000, 0.5863, 0.4124, 0.1201, 0.1032, 0.4559],
[0.0000, 0.2551, 0.8327, 0.8269, 0.1052, 0.1294]]])
.all for a and b tensor(False)
.any for a and b tensor(True)
.all for b and c tensor(True)
.any for b and c tensor(True)