10 examples of 'tensorflow load checkpoint' in Python

Every line of 'tensorflow load checkpoint' code snippets is scanned for vulnerabilities by our powerful machine learning engine that combs millions of open source libraries, ensuring your Python code is secure.

All examples are scanned by Snyk Code

By copying the Snyk Code Snippets you agree to the associated terms of use.
def load_checkpoint(sess, saver):
    """Restore the pretraining checkpoint into the given session.

    Args:
        sess: TensorFlow session to restore variables into.
        saver: tf.train.Saver used to perform the restore.

    NOTE(review): reads the module-level `config` dict for
    PRE_GEN_EPOCH / PRE_DIS_EPOCH — confirm it is defined at call time.
    """
    # Checkpoint name encodes the pretraining epochs,
    # e.g. pretrain_g120_d50.ckpt.
    ckpt = 'pretrain_g' + str(config['PRE_GEN_EPOCH']) \
        + '_d' + str(config['PRE_DIS_EPOCH']) + '.ckpt'
    saver.restore(sess, './save/' + ckpt)
    # Fixed: the original used a Python 2 `print` statement, which is a
    # SyntaxError on Python 3; dead commented-out restore code removed.
    print('checkpoint {} loaded'.format(ckpt))
def load(saver, sess, ckpt_path):
    '''Load trained weights.

    Args:
        saver: TensorFlow saver object.
        sess: TensorFlow session.
        ckpt_path: path to checkpoint file with parameters.

    Returns:
        True when a checkpoint was found and restored, False otherwise.
    '''
    state = tf.train.get_checkpoint_state(ckpt_path)
    # Guard clause: no recorded checkpoint means nothing to restore.
    if not (state and state.model_checkpoint_path):
        return False
    saver.restore(sess, state.model_checkpoint_path)
    print("Restored model parameters")
    return True
def load_ckpt(saver, sess, ckpt_dir="train"):
    """Load checkpoint from the ckpt_dir (if unspecified, this is train dir)
    and restore it to saver and sess, waiting 10 secs in the case of failure.
    Also returns checkpoint name.

    Args:
        saver: tf.train.Saver performing the restore.
        sess: TensorFlow session to restore into.
        ckpt_dir: subdirectory of FLAGS.log_root holding the checkpoint.

    Returns:
        Path of the checkpoint that was restored.
    """
    # The eval dir tracks the best model under a dedicated checkpoint file.
    latest_filename = "checkpoint_best" if ckpt_dir == "eval" else None
    # Bug fix: the original re-joined FLAGS.log_root onto ckpt_dir INSIDE the
    # retry loop, so every failed attempt prefixed log_root again and the
    # path became log_root/log_root/... — join exactly once, before the loop.
    ckpt_dir = os.path.join(FLAGS.log_root, ckpt_dir)
    while True:
        try:
            ckpt_state = tf.train.get_checkpoint_state(
                ckpt_dir, latest_filename=latest_filename)
            tf.logging.info('Loading checkpoint %s',
                            ckpt_state.model_checkpoint_path)
            saver.restore(sess, ckpt_state.model_checkpoint_path)
            return ckpt_state.model_checkpoint_path
        # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit
        # still propagate instead of being swallowed by the retry loop.
        except Exception:
            tf.logging.info(
                "Failed to load checkpoint from %s. Sleeping for %i secs...",
                ckpt_dir, 10)
            time.sleep(10)
def load_latest_checkpoint(self, ckpt_dir='checkpoints'):
    """Restore the most recent checkpoint into the instance's session.

    Args:
        ckpt_dir: directory scanned by tf.train.latest_checkpoint.
            Defaults to 'checkpoints', preserving the original hard-coded
            behavior while letting callers point at another directory.
    """
    self.saver.restore(self.sess, tf.train.latest_checkpoint(ckpt_dir))
def load(self, checkpoint_dir):
    """Restore a saved model for testing or pretraining.

    Args:
        checkpoint_dir: root directory containing per-model subdirectories.
    """
    print("\nReading Checkpoints.....\n\n")
    # Model subdirectory name is derived from architecture, image size
    # and upscaling factor.
    model_dir = "%s_%s_%s" % ("espcn", self.image_size, self.scale)
    full_dir = os.path.join(checkpoint_dir, model_dir)
    ckpt = tf.train.get_checkpoint_state(full_dir)

    # Restore only when a checkpoint record actually exists.
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_path = str(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(os.getcwd(), ckpt_path))
        print("\n Checkpoint Loading Success! %s\n\n" % ckpt_path)
    else:
        print("\n! Checkpoint Loading Failed \n\n")
def restore_from_checkpoint(self):
    """Returns scaffold function to restore parameters from checkpoint."""
    # No checkpoint configured -> nothing to restore.
    if not self._checkpoint:
        return None

    def scaffold_fn():
        """Loads pretrained model through scaffold function."""
        # Map the checkpoint's root scope onto the configured prefix.
        tf.train.init_from_checkpoint(self._checkpoint,
                                      {'/': self._checkpoint_prefix})
        return tf.train.Scaffold()

    return scaffold_fn
def get_checkpoint(logdir):
    """Return the recorded checkpoint path under logdir, or None.

    Prints a notice when the directory holds no checkpoint state.
    """
    state = tf.train.get_checkpoint_state(logdir)
    if not state:
        print('No checkpoint found')
        return None
    return state.model_checkpoint_path
def load_chkpt(saver, sess, chkptdir):
    """Restore the newest checkpoint recorded in chkptdir into sess.

    Raises:
        NameError: when no checkpoint state exists in chkptdir.
            NOTE(review): ValueError or FileNotFoundError would be more
            conventional, but callers may already catch NameError, so the
            type is kept as-is.
    """
    state = tf.train.get_checkpoint_state(chkptdir)
    if not (state and state.model_checkpoint_path):
        raise NameError('[ERROR] No checkpoint found at: %s' % chkptdir)
    # Collapse accidental double slashes before handing the path to restore.
    checkpoint_file = state.model_checkpoint_path.replace('//', '/')
    print('[DEBUG] Loading checkpoint from %s' % checkpoint_file)
    saver.restore(sess, checkpoint_file)
def restore_variables(checkpoint):
    """Build an op assigning checkpoint values to matching trainable variables.

    Args:
        checkpoint: path to a checkpoint; a falsy value yields a no-op.

    Returns:
        A grouped TensorFlow op named "restore_op".
    """
    if not checkpoint:
        return tf.no_op("restore_op")

    tf.logging.info("Loading %s" % checkpoint)

    # Read every tensor stored in the checkpoint, keyed by the bare
    # variable name (the part before any ":0" suffix).
    reader = tf.train.load_checkpoint(checkpoint)
    stored = {
        name.split(":")[0]: reader.get_tensor(name)
        for name, _ in tf.train.list_variables(checkpoint)
    }

    # Assign into each trainable variable that has a stored counterpart.
    assign_ops = []
    for var in tf.trainable_variables():
        key = var.name.split(":")[0]
        if key in stored:
            tf.logging.info("Restore %s" % var.name)
            assign_ops.append(tf.assign(var, stored[key]))

    return tf.group(*assign_ops, name="restore_op")
def restore_checkpoint_if_exists(saver, sess, logdir):
    """Looks for a checkpoint and restores the session from it if found.

    Args:
        saver: A tf.train.Saver for restoring the session.
        sess: A TensorFlow session.
        logdir: The directory to look for checkpoints in.

    Returns:
        True if a checkpoint was found and restored, False otherwise.
    """
    state = tf.train.get_checkpoint_state(logdir)
    if not state:
        return False
    # Re-anchor the checkpoint file name under logdir — presumably to
    # tolerate a moved directory whose recorded absolute path is stale.
    name = os.path.basename(state.model_checkpoint_path)
    saver.restore(sess, os.path.join(logdir, name))
    return True

Related snippets