在 "train.py"的647行,需要 "enumerate(loader)"来遍历取出“batch_idx,(input, target)两个变量。
但是会出现timm/data/transforms.py报错”line 18, in __call__np_img = np.array(pil_img, dtype=np.uint8)“,”TypeError: array() takes 1 positional argument but 2 were given“。
不知道是不是timm的版本库问题或是loader数据失败。其中,timm为0.5.4, "train.py"647行处断点”loader“的数据如下:
result<timm.data.loader.PrefetchLoader object at 0x7f8f79a7d100>
dataset: <timm.data.dataset.ImageDataset object at 0x7f8ffeb20040>
fp16: False
loader: <torch.utils.data.dataloader.DataLoader object at 0x7f9002cd8d30>
mean: tensor([[[[123.6750]], [[116.2800]], [[103.5300]]]], device='cuda:0') //[torch.Size([1, 3, 1, 1])]
mixup_enabled: True
random_erasing: RandomErasing(p=0.25, mode=const, count=(1, 1))
sampler: <torch.utils.data.sampler.RandomSampler object at 0x7f8f79a7d040>
std: tensor([[[[58.3950]],[[57.1200]],[[57.3750]]]], device='cuda:0')//[torch.Size([1, 3, 1, 1])]