from clims.
Sorry for the late reply.
For the COCO 2014 val set, the test stage took us about 2.5 h.
For the CRF stage, we reimplemented the multiprocessing part, which gives crf_coco.py. With n_jobs=16 it took about 30 h. You can tune the --n_jobs argument to speed it up further.
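As a back-of-the-envelope check on those numbers, the sketch below scales the reported 30 h at n_jobs=16 to other worker counts. It assumes the full COCO 2014 val split (40,504 images) and ideal linear scaling with n_jobs; neither assumption comes from the measurements above, and the constant and function names are illustrative only.

```python
# Rough wall-clock estimate for the CRF stage as a function of n_jobs.
# Assumptions (not measured): the full COCO val2014 split (40,504 images)
# was processed, and throughput scales linearly with the number of workers.

N_IMAGES = 40_504        # assumed size of COCO val2014
MEASURED_HOURS = 30.0    # reported wall time for the CRF stage
MEASURED_JOBS = 16       # reported n_jobs for that run


def seconds_per_image_per_worker() -> float:
    """Seconds one worker spends per image, implied by the reported run."""
    return MEASURED_HOURS * 3600 * MEASURED_JOBS / N_IMAGES


def estimated_hours(n_jobs: int) -> float:
    """Ideal (linear-scaling) wall time in hours for a given worker count."""
    return MEASURED_HOURS * MEASURED_JOBS / n_jobs


if __name__ == "__main__":
    print(f"~{seconds_per_image_per_worker():.1f} s per image per worker")
    for n in (8, 16, 32, 64):
        print(f"n_jobs={n:>2}: ~{estimated_hours(n):.1f} h")
```

In practice the speedup flattens once n_jobs exceeds the number of physical CPU cores, so treat the larger values as optimistic bounds.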
crf_coco.py:
import argparse
import json
import multiprocessing as mp
import os

import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from tqdm import tqdm

from libs.utils import DenseCRF, scores
from main_v2 import get_dataset, makedirs


def process_crf(i, dataset, logit_dir, postprocessor):
    # Load one validation image, its ground truth, and the saved logit map
    image_id, image, gt_label = dataset[i]
    filename = os.path.join(logit_dir, image_id + ".npy")
    logit = np.load(filename)

    # Upsample the logits to the image resolution and convert to probabilities
    _, H, W = image.shape
    logit = torch.FloatTensor(logit)[None, ...]
    logit = F.interpolate(logit, size=(H, W), mode="bilinear", align_corners=False)
    prob = F.softmax(logit, dim=1)[0].numpy()

    # DenseCRF takes an HxWx3 uint8 image and a CxHxW probability map
    image = image.astype(np.uint8).transpose(1, 2, 0)
    prob = postprocessor(image, prob)
    label = np.argmax(prob, axis=0)
    return label, gt_label


def crf(dataset, logit_dir, postprocessor, num_workers=4):
    print("CRF post-processing")
    pbar = tqdm(total=len(dataset), desc="CRF post-processing", ascii=True)

    # Runs in the parent process each time a worker finishes one image
    def update(*a):
        pbar.update()

    pool = mp.Pool(num_workers)
    results = []
    for i in range(len(dataset)):
        # dataset and postprocessor are pickled and sent along with every task
        results.append(pool.apply_async(process_crf,
                                        args=(i, dataset, logit_dir, postprocessor),
                                        callback=update))
    pool.close()
    pool.join()
    pbar.close()
    # get() returns results in submission order and re-raises worker errors
    results = [r.get() for r in results]
    print("CRF post-processing finished")
    return results


def main(config_path, n_jobs):
    # Configuration
    CONFIG = OmegaConf.load(config_path)
    torch.set_grad_enabled(False)
    print("# jobs:", n_jobs)

    # Dataset
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        augment=False,
    )
    print(dataset)

    # CRF post-processor
    postprocessor = DenseCRF(
        iter_max=CONFIG.CRF.ITER_MAX,
        pos_xy_std=CONFIG.CRF.POS_XY_STD,
        pos_w=CONFIG.CRF.POS_W,
        bi_xy_std=CONFIG.CRF.BI_XY_STD,
        bi_rgb_std=CONFIG.CRF.BI_RGB_STD,
        bi_w=CONFIG.CRF.BI_W,
    )

    # Path to logit files
    logit_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "features",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
        "logit",
    )
    print("Logit src:", logit_dir)
    if not os.path.isdir(logit_dir):
        raise SystemExit("Logit not found, run first: python main.py test [OPTIONS]")

    # Path to save scores
    save_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "scores",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
    )
    makedirs(save_dir)
    save_path = os.path.join(save_dir, "scores_crf_coco.json")
    print("Score dst:", save_path)

    # CRF
    results = crf(dataset, logit_dir, postprocessor, num_workers=n_jobs)

    # Evaluation
    preds, gts = zip(*results)
    # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU
    score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES)
    print(f'mIoU: {score["Mean IoU"]}')
    with open(save_path, "w") as f:
        json.dump(score, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("config_path", type=str)
    parser.add_argument("--n_jobs", type=int, default=4,
                        help="number of CRF worker processes")
    args = parser.parse_args()
    main(args.config_path, args.n_jobs)
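The crf() function above uses the standard Pool.apply_async pattern: submit one task per image, advance a progress bar from a callback, then collect results after close()/join(). A minimal stdlib-only sketch of the same pattern follows; `square` is a hypothetical stand-in for process_crf, and a plain counter replaces the tqdm bar.

```python
import multiprocessing as mp


def square(i):
    """Stand-in for process_crf: any top-level (picklable) function works."""
    return i, i * i


def run(n_items, num_workers=4):
    done = 0

    def update(_result):
        # Callbacks run in the parent process (the pool's result-handler
        # thread), so closing over a local counter or a tqdm bar is fine.
        nonlocal done
        done += 1

    pool = mp.Pool(num_workers)
    handles = [pool.apply_async(square, args=(i,), callback=update)
               for i in range(n_items)]
    pool.close()  # no more tasks will be submitted
    pool.join()   # wait for workers and pending callbacks to finish
    # get() re-raises any worker exception; the list keeps submission
    # order even if tasks completed out of order.
    results = [h.get() for h in handles]
    return results, done


if __name__ == "__main__":
    results, done = run(10)
    print(done, results[:3])
```

Because every apply_async call pickles its args, crf_coco.py as written ships dataset and postprocessor to a worker once per image. Passing them once per worker through Pool's initializer/initargs arguments is one possible way to cut that overhead (an untested suggestion, not part of crf_coco.py).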
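The evaluation step prints score["Mean IoU"] from libs.utils.scores. To illustrate what that number means, here is a minimal pure-Python sketch of the usual confusion-matrix definition of mIoU on flattened label lists. The function names are hypothetical, and the repository's scores may differ in details, for example in how classes absent from both ground truth and prediction are treated.

```python
def confusion_matrix(gts, preds, n_class):
    """n_class x n_class counts; rows = ground truth, columns = prediction.

    Pixels whose ground-truth label is outside [0, n_class) (e.g. an
    ignore label such as 255, assumed here) are skipped.
    """
    hist = [[0] * n_class for _ in range(n_class)]
    for gt, pred in zip(gts, preds):   # one (gt, pred) pair per image
        for t, p in zip(gt, pred):     # flattened pixel labels
            if 0 <= t < n_class:
                hist[t][p] += 1
    return hist


def mean_iou(gts, preds, n_class):
    """Mean of per-class IoU = TP / (TP + FP + FN) over classes that appear."""
    hist = confusion_matrix(gts, preds, n_class)
    ious = []
    for c in range(n_class):
        tp = hist[c][c]
        fn = sum(hist[c]) - tp                              # gt c, predicted other
        fp = sum(hist[r][c] for r in range(n_class)) - tp   # predicted c, gt other
        union = tp + fp + fn
        if union > 0:  # skip classes absent from both gt and prediction
            ious.append(tp / union)
    return sum(ious) / len(ious) if ious else float("nan")
```

For example, with gts=[[0, 0, 1, 1, 255]] and preds=[[0, 1, 1, 1, 0]], class 0 gets IoU 1/2 and class 1 gets 2/3, so the mean is 7/12.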