Luisgust committed
Commit af1906c · verified · 1 Parent(s): dbe99d0

Update vtoonify/train_vtoonify_d.py

Files changed (1):
  1. vtoonify/train_vtoonify_d.py +84 -1
vtoonify/train_vtoonify_d.py CHANGED
@@ -391,6 +391,7 @@ def train(args, generator, discriminator, g_optim, d_optim, g_ema, percept, pars
 
 
 
+
 if __name__ == "__main__":
 
     device = "cuda"
@@ -430,4 +431,86 @@ if __name__ == "__main__":
     if not args.pretrain:
         generator.encoder.load_state_dict(torch.load(args.encoder_path, map_location=lambda storage, loc: storage)["g_ema"])
     # we initialize the fusion modules to map f_G \otimes f_E to f_G.
-    fo
+    for k in generator.fusion_out:
+        k.conv.weight.data *= 0.01
+        k.conv.weight[:,0:k.conv.weight.shape[0],1,1].data += torch.eye(k.conv.weight.shape[0]).cuda()
+    for k in generator.fusion_skip:
+        k.weight.data *= 0.01
+        k.weight[:,0:k.weight.shape[0],1,1].data += torch.eye(k.weight.shape[0]).cuda()
+
+    accumulate(g_ema.encoder, generator.encoder, 0)
+    accumulate(g_ema.fusion_out, generator.fusion_out, 0)
+    accumulate(g_ema.fusion_skip, generator.fusion_skip, 0)
+
+    g_parameters = list(generator.encoder.parameters())
+    if not args.pretrain:
+        g_parameters = g_parameters + list(generator.fusion_out.parameters()) + list(generator.fusion_skip.parameters())
+
+    g_optim = optim.Adam(
+        g_parameters,
+        lr=args.lr,
+        betas=(0.9, 0.99),
+    )
+
+    if args.distributed:
+        generator = nn.parallel.DistributedDataParallel(
+            generator,
+            device_ids=[args.local_rank],
+            output_device=args.local_rank,
+            broadcast_buffers=False,
+            find_unused_parameters=True,
+        )
+
+    parsingpredictor = BiSeNet(n_classes=19)
+    parsingpredictor.load_state_dict(torch.load(args.faceparsing_path, map_location=lambda storage, loc: storage))
+    parsingpredictor.to(device).eval()
+    requires_grad(parsingpredictor, False)
+
+    # we apply gaussian blur to the images to avoid flicker caused by downsampling
+    down = Downsample(kernel=[1, 3, 3, 1], factor=2).to(device)
+    requires_grad(down, False)
+
+    directions = torch.tensor(np.load(args.direction_path)).to(device)
+
+    # load the style codes of DualStyleGAN
+    exstyles = np.load(args.exstyle_path, allow_pickle='TRUE').item()
+    if args.local_rank == 0 and not os.path.exists('checkpoint/%s/exstyle_code.npy'%(args.name)):
+        np.save('checkpoint/%s/exstyle_code.npy'%(args.name), exstyles, allow_pickle=True)
+    styles = []
+    with torch.no_grad():
+        for stylename in exstyles.keys():
+            exstyle = torch.tensor(exstyles[stylename]).to(device)
+            exstyle = g_ema.zplus2wplus(exstyle)
+            styles += [exstyle]
+    styles = torch.cat(styles, dim=0)
+
+    if not args.pretrain:
+        discriminator = ConditionalDiscriminator(256, use_condition=True, style_num=styles.size(0)).to(device)
+
+        d_optim = optim.Adam(
+            discriminator.parameters(),
+            lr=args.lr,
+            betas=(0.9, 0.99),
+        )
+
+        if args.distributed:
+            discriminator = nn.parallel.DistributedDataParallel(
+                discriminator,
+                device_ids=[args.local_rank],
+                output_device=args.local_rank,
+                broadcast_buffers=False,
+                find_unused_parameters=True,
+            )
+
+    percept = lpips.PerceptualLoss(model="net-lin", net="vgg", use_gpu=device.startswith("cuda"), gpu_ids=[args.local_rank])
+    requires_grad(percept.model.net, False)
+
+    pspencoder = load_psp_standalone(args.style_encoder_path, device)
+
+    if args.local_rank == 0:
+        print('Models and data successfully loaded!')
+
+    if args.pretrain:
+        pretrain(args, generator, g_optim, g_ema, parsingpredictor, down, directions, styles, device)
+    else:
+        train(args, generator, discriminator, g_optim, d_optim, g_ema, percept, parsingpredictor, down, pspencoder, directions, styles, device)
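
A note on the fusion-module initialization added above: scaling the random conv weights by 0.01 and adding an identity matrix at the center tap (index 1,1) of the 3x3 kernel makes each fusion layer start as an approximate pass-through of the generator feature f_G, so training begins from unmodified DualStyleGAN behavior. A minimal sketch of the same trick, assuming a fusion conv that maps the concatenation [f_G, f_E] (2*C channels) back to C channels; the width C and all shapes are illustrative:

import torch
import torch.nn as nn

C = 64                                    # hypothetical channel width
fusion = nn.Conv2d(2 * C, C, kernel_size=3, padding=1)

with torch.no_grad():
    fusion.weight *= 0.01                 # damp the random initialization
    # identity on the f_G half of the input at the kernel center (1, 1),
    # so output ~= f_G at step 0
    fusion.weight[:, 0:C, 1, 1] += torch.eye(C)
    fusion.bias.zero_()                   # zeroed only to make the check below tight;
                                          # the training code leaves the bias as-is

f_G = torch.randn(1, C, 32, 32)
f_E = torch.zeros(1, C, 32, 32)
out = fusion(torch.cat([f_G, f_E], dim=1))
print((out - f_G).abs().max())            # small: identity plus a 0.01-scaled residual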
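
The accumulate(..., 0) calls then copy these freshly initialized encoder and fusion weights into the EMA copy g_ema: accumulate is the usual stylegan2-pytorch EMA helper, and a decay of 0 reduces the moving average to a hard copy. For reference, a sketch of that helper (the exact implementation in this repo may differ slightly):

import torch

def accumulate(model1, model2, decay=0.999):
    # exponential moving average: model1 <- decay * model1 + (1 - decay) * model2
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        # decay=0 turns this into a plain copy, (re)initializing the EMA weights
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)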
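
On Downsample(kernel=[1, 3, 3, 1], factor=2): [1, 3, 3, 1] is a separable binomial filter (a gaussian approximation), so the op low-pass filters before subsampling; skipping the blur would alias high frequencies, which shows up as temporal flicker on video. A rough plain-PyTorch equivalent, assuming stylegan2-style Downsample semantics (the reflect padding here is an illustrative choice, not necessarily what the op uses):

import torch
import torch.nn.functional as F

k = torch.tensor([1., 3., 3., 1.])
kernel2d = torch.outer(k, k)
kernel2d = kernel2d / kernel2d.sum()       # normalize to preserve brightness

def blur_downsample(x, factor=2):
    # depthwise blur with the 4x4 binomial kernel, then strided subsampling
    c = x.shape[1]
    w = kernel2d.to(x).repeat(c, 1, 1, 1)  # (c, 1, 4, 4) depthwise weights
    x = F.pad(x, (1, 2, 1, 2), mode='reflect')
    return F.conv2d(x, w, stride=factor, groups=c)

y = blur_downsample(torch.randn(1, 3, 256, 256))
print(y.shape)                             # torch.Size([1, 3, 128, 128])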