10 examples of 'load_state_dict' in Python

Every line of these 'load_state_dict' code snippets is scanned for vulnerabilities by Snyk's machine-learning engine, which combs millions of open-source libraries to help keep your Python code secure.

All examples are scanned by Snyk Code

By copying the Snyk Code Snippets, you agree to the terms linked on the original page (the link text is missing from this copy).
def load_state_dict(self, state):
    """Restore instance attributes from the mapping *state*.

    Each (name, value) pair is written straight into the instance ``__dict__``,
    so any custom ``__setattr__`` on the class is bypassed.
    """
    instance_vars = self.__dict__
    instance_vars.update(state.items())
def load_state_dict(self, state_dict):
    """Restore both sub-networks from an indexable pair of state dicts.

    ``state_dict[0]`` is loaded into ``self.vis_net`` first, then
    ``state_dict[1]`` into ``self.txt_net``.
    """
    for position, net_attr in ((0, 'vis_net'), (1, 'txt_net')):
        getattr(self, net_attr).load_state_dict(state_dict[position])
def load_state_dict(self, state):
    """Load model state.

    Restores the network architecture hyper-parameters, rebuilds the network
    with ``create_network`` and loads its weights from ``state['net']``, then
    restores the search hyper-parameters and, only if ``'rng'`` is present in
    *state*, the random number generator state.
    """
    # Architecture hyper-parameters come first: create_network needs them.
    architecture_keys = ('network_type', 'board_size', 'num_blocks', 'base_chans')
    for key in architecture_keys:
        setattr(self, key, state[key])
    self.net = create_network(
        self.network_type,
        self.board_size,
        self.num_blocks,
        self.base_chans,
    )
    self.net.load_state_dict(state['net'])

    search_keys = (
        'simulations',
        'search_batch_size',
        'exploration_coef',
        'exploration_depth',
        'exploration_noise_alpha',
        'exploration_noise_scale',
        'exploration_temperature',
    )
    for key in search_keys:
        setattr(self, key, state[key])

    # The RNG state is optional in *state*.
    if 'rng' not in state:
        return
    # NOTE(review): calls the ``__setstate__`` dunder directly; the concrete
    # type of ``self.rng`` is not visible here — confirm this matches how the
    # state was captured.
    self.rng.__setstate__(state['rng'])
def load_state_dict(self, saved):
    """Restore the ``m``, ``v`` and ``n`` buffers from *saved* as float tensors.

    Each entry is passed through the legacy ``torch.FloatTensor`` constructor.
    NOTE(review): that constructor treats a bare int as a *size* (giving an
    uninitialized tensor), so the saved values are presumably sequences or
    tensors — confirm against the matching save code.
    """
    for buffer_name in ('m', 'v', 'n'):
        setattr(self, buffer_name, torch.FloatTensor(saved[buffer_name]))
def load_state_dict(self, state: Dict) -> None:
    """Load replay-buffer contents.

    Copies ``examples``, ``write_idx`` and ``fresh_counter`` from *state*
    onto the buffer, in that order; the values are assigned as-is, not copied.
    """
    for attribute in ('examples', 'write_idx', 'fresh_counter'):
        setattr(self, attribute, state[attribute])
def load_state_dict(model: torch.nn.Module, state_dict: Dict, skip_wrong_shape: bool = False) -> None:
    """Copy the matching entries of *state_dict* into *model*.

    Keys in *state_dict* that the model does not have are ignored. For keys the
    model does have, the tensor shapes must agree: on a mismatch a
    ``ValueError`` is raised, unless *skip_wrong_shape* is true, in which case
    that entry is skipped and the model keeps its current value. Nothing is
    loaded if an error is raised.

    Args:
        model: Module updated in place via ``model.load_state_dict``.
        state_dict: Mapping from parameter/buffer names to tensors.
        skip_wrong_shape: Skip, rather than reject, entries whose shape
            differs from the model's.

    Raises:
        ValueError: On a shape mismatch when *skip_wrong_shape* is false.
            ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers still catch it.
    """
    model_state_dict = model.state_dict()

    for key, value in state_dict.items():
        if key not in model_state_dict:
            continue  # extra checkpoint entries are tolerated
        expected_shape = model_state_dict[key].shape
        if expected_shape == value.shape:
            model_state_dict[key] = value
        elif not skip_wrong_shape:
            m = (
                f"Shapes of the '{key}' parameters do not match: "
                f"{expected_shape} vs {value.shape}"
            )
            raise ValueError(m)

    # Every key originates from the model itself, so a strict load succeeds.
    model.load_state_dict(model_state_dict)
def set_state_dict(model: torch.nn.Module, state_dict: dict):
    """Set state dict of a model.

    Also works with checkpoints saved from a ``torch.nn.DataParallel`` model.
    If the direct load fails with ``RuntimeError`` and at least one key starts
    with ``'module.'``, that leading prefix is stripped from every key carrying
    it and the load is retried.

    Raises:
        RuntimeError: If the direct load fails and no key has the prefix
            (the original error is re-raised unchanged), or if the retry
            with stripped keys fails as well.
    """
    prefix = 'module.'
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # TODO: any RuntimeError triggers the retry; when no key carries the
        # prefix a retry cannot help, so surface the original error instead.
        if not any(key.startswith(prefix) for key in state_dict):
            raise
        # Strip only a *leading* prefix: str.replace would also mangle names
        # that merely contain "module." (e.g. "submodule.weight" -> "subweight").
        new_state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
        model.load_state_dict(new_state_dict)
def load_my_state_dict(self, state_dict, seq_len):
    """Copy matching checkpoint tensors into this module, inflating the stem conv.

    Entries whose names are absent from ``self.state_dict()`` are ignored.
    ``Parameter`` values are unwrapped to their ``.data`` tensor. For entries
    whose name contains ``'Conv2d_1a_3x3'`` but not ``'bn'``, the tensor is
    tiled ``seq_len`` times along dimension 1 and divided by ``seq_len``. The
    division means that if every tiled channel group sees the same input, the
    output equals that of the original weights. Presumably this lets the model
    take ``seq_len`` stacked inputs — confirm against the caller.

    Args:
        state_dict: Mapping of parameter names to tensors (or ``Parameter``s).
        seq_len: Tiling factor applied to the matching ``Conv2d_1a_3x3`` entries.

    Raises:
        RuntimeError: If copying a tensor into the module fails (e.g. a shape
            mismatch). The underlying error is chained as ``__cause__``.
    """
    own_state = self.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            continue
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data

        if 'Conv2d_1a_3x3' in name and 'bn' not in name:
            param = param.repeat(1, seq_len, 1, 1) / float(seq_len)

        try:
            own_state[name].copy_(param)
        except Exception as err:
            raise RuntimeError('While copying the parameter named {}, '
                               'whose dimensions in the model are {} and '
                               'whose dimensions in the checkpoint are {}.'
                               .format(name, own_state[name].size(), param.size())) from err
def load_state_dict(self, d):
    """Restore every attribute listed by ``self._state_attrs()`` from *d*.

    Keys in *d* that are not listed are ignored; a listed name missing from
    *d* raises ``KeyError``.
    """
    wanted = self._state_attrs()
    for attr_name in wanted:
        restored = d[attr_name]
        setattr(self, attr_name, restored)
def load(self, path):
    """Read a checkpoint from *path* with ``t.load`` and apply it via ``load_state_dict``."""
    # NOTE(review): ``t`` looks like an alias for torch, and torch.load is
    # pickle-based; only load checkpoints from trusted sources — confirm
    # what ``t`` is bound to.
    checkpoint = t.load(path)
    self.load_state_dict(checkpoint)

Related snippets