Call metric with args:
(tensor([ 5, 22, 4, 15, 25, 11, 5, 3, 30, 14, 11, 20, 8, 23, 12, 15]), tensor([[-4.2334e-01, 3.2128e-01, -1.6276e-01, -5.2903e-02, -4.4799e-01,
8.1375e-02, -8.1450e-01, -8.3723e-01, -1.5693e+00, 1.8602e-01,
9.6764e-02, -7.7650e-01, 8.9860e-01, -5.6379e-01, -1.9089e-01,
-7.3312e-01, -2.6684e-01, 1.9810e-01, -4.7601e-02, 1.1084e-01,
-3.5942e-01, -8.2719e-01, -2.6982e-01, -4.1683e-01, -6.0756e-02,
-6.5203e-01, 4.8693e-01, 4.0284e-02, -5.9342e-01, -6.3951e-01,
1.1935e-01],
[-2.0829e-01, 2.9670e-01, -1.5979e-01, -1.9088e-01, -6.3610e-01,
-1.1845e-01, -5.2010e-01, -8.1823e-01, -1.0798e+00, 1.4608e-01,
4.3075e-02, -8.3660e-01, 5.8538e-01, -5.8224e-01, -1.6127e-01,
-6.5233e-01, -1.3462e-01, 1.6564e-01, -2.8463e-01, 3.5646e-02,
-2.0126e-01, -8.1365e-01, -3.1208e-01, -4.2726e-01, 3.6077e-02,
-4.0859e-01, 1.2671e-01, 3.2303e-02, -5.0047e-01, -6.9472e-01,
2.6332e-02],
[-4.5164e-01, 1.0030e-01, -4.5018e-01, -9.2652e-02, -4.8574e-01,
1.8752e-01, -5.9322e-01, -8.1456e-01, -1.4288e+00, 1.1793e-01,
5.3347e-02, -8.2513e-01, 6.7262e-01, -5.7353e-01, -1.9667e-01,
-7.2191e-01, -3.4521e-01, -1.4000e-01, -1.1416e-01, 1.3564e-01,
-2.1106e-01, -6.5105e-01, -2.7077e-02, -2.6946e-01, 1.2791e-01,
-5.0441e-01, 3.5882e-01, -1.9064e-01, -4.8390e-01, -7.9666e-01,
-1.8060e-02],
[-7.6481e-01, -4.0101e-04, -3.2567e-01, -4.3850e-01, -8.2028e-01,
-1.1373e-01, -1.3069e+00, -1.0709e+00, -1.8884e+00, -7.4310e-02,
1.0504e-01, -9.8710e-01, 1.3911e+00, -9.7736e-01, -7.1756e-01,
-8.8963e-01, -7.9884e-01, 5.6421e-01, -6.9289e-01, -7.8426e-02,
-4.4967e-01, -1.0660e+00, -6.6833e-01, -5.6144e-01, -4.5779e-01,
-1.0289e+00, 4.8788e-01, 2.6905e-02, -2.4126e-01, -1.0833e+00,
-3.6447e-01],
[-3.5791e-01, 3.5727e-02, -2.6489e-01, -3.0950e-02, -2.8023e-01,
-2.7610e-02, -4.3282e-01, -5.4800e-01, -9.9201e-01, 1.5896e-01,
-8.9530e-02, -6.6279e-01, 4.8260e-01, -4.1675e-01, -2.7672e-01,
-4.3392e-01, -5.3007e-02, -1.1136e-01, 3.9124e-02, 2.8032e-01,
-2.3642e-02, -4.3771e-01, -3.4402e-02, -1.6803e-01, 3.0441e-01,
-3.3988e-01, 1.8902e-01, -1.0924e-01, -3.0610e-01, -5.7769e-01,
9.7626e-02],
[-2.3366e-01, 9.6609e-02, -1.1173e-01, -7.6165e-02, -4.5224e-01,
1.7253e-02, -5.2528e-01, -6.2184e-01, -1.1521e+00, 1.8913e-01,
3.8723e-02, -6.6799e-01, 6.1483e-01, -5.1210e-01, -8.5522e-02,
-5.5686e-01, -1.8134e-01, 9.1830e-02, -7.6381e-02, 1.9169e-01,
-2.3889e-01, -6.2782e-01, -9.4335e-02, -3.1402e-01, 1.8291e-02,
-4.8800e-01, 1.8026e-01, 5.5238e-03, -4.6877e-01, -7.9068e-01,
3.4459e-01],
[-4.3803e-01, -8.9626e-02, -1.6414e-01, -3.3728e-01, -7.3173e-01,
1.8028e-01, -8.8425e-01, -9.1455e-01, -1.4416e+00, 1.1838e-01,
-1.5525e-01, -1.0517e+00, 9.1864e-01, -8.9308e-01, -1.4303e-01,
-9.8013e-01, -5.0859e-01, 3.3850e-01, -8.5587e-01, -1.8356e-01,
-2.9663e-01, -1.0568e+00, -4.2416e-01, -2.7945e-01, -1.3507e-01,
-9.7662e-01, 3.5143e-01, 8.2564e-02, -4.2828e-01, -1.1080e+00,
-2.4122e-01],
[-3.7862e-01, 1.0303e-01, -1.3942e-01, -7.3819e-02, -6.1260e-01,
6.9244e-02, -5.1330e-01, -6.4932e-01, -1.2605e+00, 2.1687e-01,
1.2845e-01, -8.5608e-01, 7.1500e-01, -6.3164e-01, -2.9660e-01,
-6.5688e-01, -4.1604e-01, 5.1267e-02, -3.9499e-01, 1.7832e-01,
-2.4692e-01, -8.2239e-01, -2.5374e-01, -3.2104e-01, -7.5938e-02,
-5.8123e-01, 3.2702e-01, 6.7445e-02, -3.6465e-01, -8.0817e-01,
-1.1762e-02],
[-3.4601e-01, 3.1715e-01, -1.9358e-02, -2.0373e-01, -5.7190e-01,
-1.3176e-01, -8.3668e-01, -8.6941e-01, -1.5289e+00, 2.3381e-01,
2.1028e-01, -9.0012e-01, 1.2456e+00, -7.5695e-01, -5.4301e-01,
-3.4929e-01, -1.5777e-01, 2.1443e-01, -3.5490e-02, -2.8588e-01,
-1.9435e-01, -8.4973e-01, -7.4427e-02, -3.4301e-01, -3.4359e-01,
-1.0014e+00, 4.1498e-01, 3.5629e-02, -6.1123e-01, -9.7774e-01,
-1.0896e-01],
[-2.7988e-01, 2.2090e-01, -3.7169e-02, -3.0769e-01, -5.9117e-01,
1.4772e-01, -8.5652e-01, -8.5170e-01, -1.5827e+00, 2.0144e-01,
2.2082e-02, -8.4365e-01, 1.0827e+00, -6.9410e-01, -6.8745e-01,
-5.0489e-01, -4.6304e-01, 5.5176e-01, -2.7457e-01, -3.6374e-01,
-6.6554e-02, -9.0053e-01, -2.5255e-01, -3.4276e-01, -2.4984e-01,
-1.0166e+00, 6.0744e-01, 2.9452e-02, -6.8286e-01, -1.0930e+00,
1.8820e-01],
[-7.2289e-01, 6.5829e-02, -2.5874e-01, -1.2124e-01, -4.2901e-01,
4.3952e-02, -9.0203e-01, -6.3939e-01, -1.6698e+00, 5.2024e-02,
1.9740e-01, -8.3556e-01, 1.1886e+00, -5.5185e-01, -3.0437e-01,
-9.8564e-01, -5.3967e-01, 4.0762e-01, -4.2439e-01, 8.6231e-02,
-4.7848e-01, -9.7442e-01, -4.4225e-01, -2.5567e-01, -3.3004e-01,
-1.0718e+00, 4.2354e-01, 5.4335e-02, -6.3567e-01, -8.8289e-01,
-1.5797e-01],
[-4.1377e-01, 1.7438e-01, -5.5405e-02, -2.7336e-01, -5.0366e-01,
2.1609e-01, -5.5138e-01, -7.6795e-01, -1.2920e+00, 2.1778e-01,
1.3097e-02, -8.1997e-01, 8.6564e-01, -4.7379e-01, -1.5658e-01,
-7.9150e-01, -3.4465e-01, 2.2048e-01, -2.1311e-01, -6.4284e-02,
-2.7810e-01, -7.8189e-01, -3.6181e-01, -3.8200e-01, -1.9598e-01,
-6.1759e-01, 2.2910e-01, -4.4800e-02, -5.0451e-01, -7.3038e-01,
-6.4414e-02],
[-3.7161e-01, 1.5056e-01, -2.5408e-01, -5.7606e-02, -4.7856e-01,
-1.3512e-01, -6.6447e-01, -5.7490e-01, -1.4241e+00, 1.3914e-01,
2.2788e-01, -7.0752e-01, 7.6762e-01, -5.1654e-01, -2.4347e-01,
-5.1930e-01, -1.3222e-01, 4.4097e-02, -5.8084e-02, 9.3724e-02,
-3.4235e-01, -5.6912e-01, -8.5865e-02, -1.5555e-01, -2.0630e-01,
-5.5908e-01, 2.2383e-01, -8.9230e-02, -3.3783e-01, -6.7599e-01,
1.8993e-02],
[-2.7876e-01, 3.6727e-02, -2.2710e-01, -8.0344e-02, -2.6371e-01,
3.0302e-03, -2.2148e-01, -4.2502e-01, -7.4349e-01, 1.0321e-01,
4.9534e-02, -5.3572e-01, 3.7887e-01, -2.7889e-01, -2.5053e-01,
-4.0012e-01, -2.6487e-01, -1.0883e-01, -9.3237e-02, 1.6346e-01,
-1.4623e-01, -4.0750e-01, -1.0636e-01, -1.9950e-01, 1.5323e-01,
-2.6691e-01, 1.8850e-01, -7.8911e-02, -3.4073e-01, -5.4825e-01,
1.2803e-01],
[-3.7803e-01, 1.4453e-02, -2.5635e-01, -7.4367e-02, -6.4179e-01,
-1.7266e-01, -1.0710e+00, -7.4324e-01, -1.6834e+00, -3.6002e-02,
1.4780e-01, -1.0021e+00, 1.1677e+00, -7.4163e-01, -1.1192e-01,
-7.3385e-01, -2.2223e-01, -1.8710e-01, 2.9915e-02, 3.3439e-01,
-4.6200e-01, -7.5395e-01, -1.9572e-01, 1.3413e-01, 6.0762e-02,
-8.3109e-01, 1.2702e-01, -1.2926e-02, -4.4783e-01, -9.0948e-01,
-6.6723e-02],
[-6.4955e-01, -2.5407e-01, -2.0360e-01, -3.5873e-01, -6.9094e-01,
2.0167e-01, -9.4770e-01, -9.5162e-01, -2.0917e+00, 1.1412e-01,
2.2403e-01, -1.1150e+00, 1.3356e+00, -5.3853e-01, -5.8843e-01,
-9.8410e-01, -7.1797e-01, 3.6223e-01, -3.5720e-01, -1.9642e-01,
-5.0072e-01, -1.1756e+00, -3.0184e-01, -3.7104e-01, -3.3182e-01,
-1.5207e+00, 4.6467e-01, -1.8616e-01, -2.6982e-01, -1.1789e+00,
-1.5017e-01]], grad_fn=<AddmmBackward>))
{}
Traceback (most recent call last):
File "/home/arqwer/0Links/2019-2020/domain_adaptation/DomainAdaptation/example.py", line 37, in <module>
tr.fit(train_gen_s, train_gen_t,
File "/home/arqwer/study/2019-2020/domain_adaptation/DomainAdaptation/trainer.py", line 50, in fit
src_metrics = self.score(src_val_data, metrics)
File "/home/arqwer/study/2019-2020/domain_adaptation/DomainAdaptation/trainer.py", line 61, in score
metric(true_classes, pred_classes)
File "/home/arqwer/0Links/2019-2020/domain_adaptation/DomainAdaptation/example.py", line 31, in __call__
print("metric: ", self.metric(*args, **kwargs))
File "/home/arqwer/study/2019-2020/domain_adaptation/DomainAdaptation/metrics/accuracy.py", line 23, in __call__
self._correct += (y_true == y_predict).sum().item()
File "/home/arqwer/.local/lib/python3.8/site-packages/torch/tensor.py", line 28, in wrapped
return f(*args, **kwargs)
RuntimeError: The size of tensor a (16) must match the size of tensor b (31) at non-singleton dimension 1
Process finished with exit code 1
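The traceback shows that `AccuracyScore` is called with the true labels (shape `[16]`) and the raw logits (shape `[16, 31]`). It then compares the two element-wise (`y_true == y_predict`), which fails because the shapes do not match. Either `AccuracyScore` should be changed to accept `(labels, logits)`, or `Trainer` should convert the logits to predicted labels before calling the metric. Since we want `Trainer` to be as general as possible, we should take the first option and change the `AccuracyScore` interface.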