import os
from importlib import import_module

import torch
import torch.nn as nn
import math
from torch.autograd import Variable


class Model(nn.Module):
    """Wrapper around a network built by ``model.<args.model>.make_model``.

    Besides delegating the forward pass, the wrapper takes care of device
    placement, optional half precision, optional ``nn.DataParallel``,
    checkpoint saving/loading, and two evaluation-time helpers:
    patch-wise inference (``forward_chop``) and flip/transpose
    self-ensembling (``forward_x8``).
    """

    def __init__(self, args, ckp):
        """Build the wrapped model and load its weights.

        Args:
            args: option namespace; the attributes read here are ``scale``,
                ``self_ensemble``, ``chop``, ``precision``, ``cpu``,
                ``n_GPUs``, ``save_models``, ``model``, ``pre_train`` and
                ``resume``.
            ckp: checkpoint helper; ``ckp.dir`` is the experiment directory
                and ``ckp.log_file`` receives the printed architecture.
        """
        super(Model, self).__init__()
        print('Making model...')

        self.scale = args.scale
        self.idx_scale = 0
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        self.device = torch.device('cpu' if args.cpu else 'cuda')
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models

        # The architecture lives in the project package ``model.<name>``.
        module = import_module('model.' + args.model.lower())
        self.model = module.make_model(args).to(self.device)
        if args.precision == 'half':
            self.model.half()

        if not args.cpu and args.n_GPUs > 1:
            self.model = nn.DataParallel(self.model, range(args.n_GPUs))

        self.load(
            ckp.dir,
            pre_train=args.pre_train,
            resume=args.resume,
            cpu=args.cpu
        )
        print(self.model, file=ckp.log_file)

    def forward(self, x, scale, pos_mat):
        """Run the wrapped model, optionally chopped and/or self-ensembled.

        Args:
            x: input batch, indexed as ``(b, c, h, w)`` by the helpers.
            scale: stored on ``self.scale`` and forwarded to the wrapped
                model's ``set_scale`` when that method exists.
            pos_mat: second positional argument handed to the wrapped model.
                NOTE(review): it is passed unchanged to every patch and
                every flipped/transposed input -- confirm that is valid for
                the wrapped model.

        Returns:
            The model output for ``x``.
        """
        self.scale = scale
        target = self.get_model()
        if hasattr(target, 'set_scale'):
            target.set_scale(scale)

        evaluating = not self.training
        if self.self_ensemble and evaluating:
            base = self.forward_chop if self.chop else self.model.forward

            # forward_x8 calls its function with the image only, so the
            # extra ``pos_mat`` argument has to be bound here.
            def forward_function(v):
                return base(v, pos_mat)

            return self.forward_x8(x, forward_function)
        if self.chop and evaluating:
            return self.forward_chop(x, pos_mat)
        return self.model(x, pos_mat)

    def get_model(self):
        """Return the bare network, unwrapping ``nn.DataParallel`` if used."""
        if self.n_GPUs <= 1 or self.cpu:
            return self.model
        return self.model.module

    def state_dict(self, *args, **kwargs):
        """Return the state dict of the bare (unwrapped) network."""
        return self.get_model().state_dict(*args, **kwargs)

    def save(self, apath, epoch, is_best=False):
        """Write the current weights under ``<apath>/model``.

        ``model_latest.pt`` is always written; ``model_best.pt`` when
        ``is_best`` is true; ``model_<epoch>.pt`` when ``self.save_models``
        is set.
        """
        state = self.get_model().state_dict()
        names = ['model_latest.pt']
        if is_best:
            names.append('model_best.pt')
        if self.save_models:
            names.append('model_{}.pt'.format(epoch))
        for name in names:
            torch.save(state, os.path.join(apath, 'model', name))

    def load(self, apath, pre_train='.', resume=-1, cpu=False):
        """Load weights into the bare network (non-strict).

        Args:
            apath: experiment directory containing a ``model`` folder.
            pre_train: path to pre-trained weights; ``'.'`` means none.
            resume: ``-1`` loads ``model_latest.pt``; ``0`` loads
                ``pre_train`` (if given) and otherwise nothing; any other
                value loads ``model_<resume>.pt``.
            cpu: when true, tensors are mapped onto the CPU while loading.
        """
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        if resume == -1:
            path, done_msg = os.path.join(apath, 'model', 'model_latest.pt'), None
        elif resume == 0:
            if pre_train == '.':
                return
            print('Loading model from {}'.format(pre_train))
            path, done_msg = pre_train, 'load_model_mode=1'
        else:
            path = os.path.join(apath, 'model', 'model_{}.pt'.format(resume))
            done_msg = 'load_model_mode=2'

        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        self.get_model().load_state_dict(
            torch.load(path, **kwargs),
            strict=False
        )
        if done_msg is not None:
            print(done_msg)

    def forward_chop(self, x, pos_mat, shave=10, min_size=160000):
        """Run the model on four overlapping quadrants and stitch the result.

        The input is split into four patches of size
        ``(h // 2 + shave, w // 2 + shave)``. Patches whose area is below
        ``min_size`` are sent through the model directly; larger ones are
        chopped recursively.

        NOTE(review): the stitching assumes the model enlarges each spatial
        dimension by exactly ``ceil(scale)`` and keeps the channel count;
        it also assumes ``h // 2 + shave <= h`` (same for ``w``).

        Args:
            x: input batch ``(b, c, h, w)``.
            pos_mat: forwarded unchanged to every model call.
            shave: overlap, in input pixels, added to each quadrant.
            min_size: patch area below which recursion stops.

        Returns:
            A tensor of shape ``(b, c, ceil(scale) * h, ceil(scale) * w)``.
        """
        # forward() stores whatever ``scale`` it was given, so accept both a
        # per-index sequence and a plain number.
        try:
            scale = self.scale[self.idx_scale]
        except TypeError:
            scale = self.scale
        # Clamp to >= 1 so the range() step below can never be zero.
        n_GPUs = max(1, min(self.n_GPUs, 4))
        b, c, h, w = x.size()
        h_half, w_half = h // 2, w // 2
        h_size, w_size = h_half + shave, w_half + shave
        lr_list = [
            x[:, :, 0:h_size, 0:w_size],
            x[:, :, 0:h_size, (w - w_size):w],
            x[:, :, (h - h_size):h, 0:w_size],
            x[:, :, (h - h_size):h, (w - w_size):w]]

        if w_size * h_size < min_size:
            sr_list = []
            for i in range(0, 4, n_GPUs):
                group = lr_list[i:(i + n_GPUs)]
                sr_batch = self.model(torch.cat(group, dim=0), pos_mat)
                # Split by the real group length: the last group is shorter
                # when n_GPUs does not divide 4.
                sr_list.extend(sr_batch.chunk(len(group), dim=0))
        else:
            sr_list = [
                self.forward_chop(patch, pos_mat, shave=shave, min_size=min_size)
                for patch in lr_list
            ]

        scale = math.ceil(scale)
        h, w = scale * h, scale * w
        h_half, w_half = scale * h_half, scale * w_half
        h_size, w_size = scale * h_size, scale * w_size

        # Every element is overwritten by the four assignments below.
        output = x.new(b, c, h, w)
        output[:, :, 0:h_half, 0:w_half] \
            = sr_list[0][:, :, 0:h_half, 0:w_half]
        output[:, :, 0:h_half, w_half:w] \
            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
        output[:, :, h_half:h, 0:w_half] \
            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
        output[:, :, h_half:h, w_half:w] \
            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

        return output

    def forward_x8(self, x, forward_function):
        """Self-ensemble: average predictions over 8 flips/transposes.

        Args:
            x: input batch ``(b, c, h, w)``.
            forward_function: callable taking one image batch and returning
                the prediction for it.

        Returns:
            The per-sample mean of the eight back-transformed predictions,
            with the same shape as a single prediction.
        """
        def _transform(v, op):
            # Flip/transposition is done through NumPy on a CPU copy.
            if self.precision != 'single':
                v = v.float()

            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()
            else:
                raise ValueError('unknown transform op: {!r}'.format(op))

            ret = torch.Tensor(tfnp).to(self.device)
            if self.precision == 'half':
                ret = ret.half()

            return ret

        # After the loop, index i encodes the applied ops bitwise:
        # bit 0 -> 'v', bit 1 -> 'h', bit 2 -> 't'.
        lr_list = [x]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])

        sr_list = [forward_function(aug) for aug in lr_list]
        # Undo the ops in reverse order of application: 't', then 'h', 'v'.
        for i, sr in enumerate(sr_list):
            if i > 3:
                sr = _transform(sr, 't')
            if i % 4 > 1:
                sr = _transform(sr, 'h')
            if i % 2 == 1:
                sr = _transform(sr, 'v')
            sr_list[i] = sr

        # Stack on a new leading axis so the mean runs over the eight
        # predictions only and never mixes different samples of the batch.
        return torch.stack(sr_list, dim=0).mean(dim=0)