def train(self, x_data):
    """
    Train the VAE on ``x_data`` and save the trained weights to disk.

    The data is passed to the parent ``train`` as both input and target,
    then shifted by its median and divided by its standard deviation. Both
    statistics are computed over every element (no axis is given) and are
    stored in ``meanstd_x.npy`` under ``self.fullfilepath``. Training progress
    is appended to ``log.csv`` and the weights are written to
    ``model_weights.h5`` in the same location.

    :param x_data: training data; must expose ``shape[0]`` and ``shape[1]``
        (presumably samples x features -- TODO confirm against the caller)
    :return: None
    :raises RuntimeError: if ``self.task`` is ``'classification'``
    """
    # The original unpacked both return values into the same name, so only
    # the second one survives; that binding is kept here.
    _, x_data = super().train(x_data, x_data)

    # Reject unsupported tasks before anything is written to disk and before
    # the data is rescaled in place.
    if self.task == 'classification':
        raise RuntimeError('astroNN VAE does not support classification task')

    # NOTE(review): the centre is the median although the file is named
    # "meanstd" -- confirm this is intended.
    x_mu_std = np.vstack((np.median(x_data), np.std(x_data)))
    np.save(self.fullfilepath + 'meanstd_x.npy', x_mu_std)
    # In-place normalisation: avoids a second copy of the data, but modifies
    # the array returned by the parent ``train``.
    x_data -= x_mu_std[0]
    x_data /= x_mu_std[1]

    csv_logger = CSVLogger(self.fullfilepath + 'log.csv', append=True, separator=',')
    reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.5, epsilon=self.reduce_lr_epsilon,
                                  patience=self.reduce_lr_patience, min_lr=self.reduce_lr_min, mode='min',
                                  verbose=2)

    self.compile()
    self.plot_model()

    training_generator = DataGenerator(x_data.shape[1], self.batch_size).generate(x_data)

    # Integer division: a trailing partial batch is not counted as a step.
    self.keras_model.fit_generator(generator=training_generator,
                                   steps_per_epoch=x_data.shape[0] // self.batch_size,
                                   epochs=self.max_epochs, max_queue_size=20, verbose=2,
                                   workers=os.cpu_count(),
                                   callbacks=[reduce_lr, csv_logger])

    # NOTE(review): fitting uses ``self.keras_model`` while the weights are
    # saved from ``self.keras_vae`` -- presumably they share layers; verify.
    astronn_model = 'model_weights.h5'
    self.keras_vae.save_weights(self.fullfilepath + astronn_model)
    print(astronn_model + ' saved to {}'.format(self.fullfilepath + astronn_model))

    K.clear_session()