查找图像通道之间的均值和标准差PyTorch

ch1maera

假设我有一张张量为张量的图像,其尺寸为(B x C x W x H),其中B是批量大小,C是图像中通道的数量,W和H是宽度的高度图片。我正在寻找使用该transforms.Normalize()函数针对C图像通道上数据集的均值和标准差对图像进行归一化的方法,这意味着我想要一个结果张量为1 x C的形式。是否有一种简单的方法这个?

我尝试过torch.view(C, -1).mean(1)torch.view(C, -1).std(1)但出现错误:

view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

编辑

在研究了view()PyTorch的工作原理之后,我知道了为什么我的方法行不通。但是,我仍然不知道如何获取每个通道的均值和标准差。

trsvchn

您只需要以正确的方式重新排列批处理张量:从[B, C, W, H][B, C, W * H]

batch = batch.view(batch.size(0), batch.size(1), -1)

这是有关随机数据的完整用法示例:

码:

import torch
from torch.utils.data import TensorDataset, DataLoader

data = torch.randn(64, 3, 28, 28)
labels = torch.zeros(64, 1)
dataset = TensorDataset(data, labels)
loader = DataLoader(dataset, batch_size=8)

nimages = 0
mean = 0.
std = 0.
for batch, _ in loader:
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Compute mean and std here
    mean += batch.mean(2).sum(0) 
    std += batch.std(2).sum(0)

# Final step
mean /= nimages
std /= nimages

print(mean)
print(std)

输出:

tensor([-0.0029, -0.0022, -0.0036])
tensor([0.9942, 0.9939, 0.9923])

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章