diff --git a/matscipy/dislocation.py b/matscipy/dislocation.py index 1565dbb6..0a69c3e4 100644 --- a/matscipy/dislocation.py +++ b/matscipy/dislocation.py @@ -1014,7 +1014,7 @@ def cost_function(pos, dislo, bulk, cylinder_r, elastic_param, def fit_core_position(dislo_image, bulk, elastic_param, hard_core=False, core_radius=10, current_pos=None, bulk_neighbours=None, - origin=(0, 0)): + origin=(0, 0), return_error=False): """ Use `cost_function()` to fit atomic positions to Stroh solution with @@ -1037,6 +1037,9 @@ def fit_core_position(dislo_image, bulk, elastic_param, hard_core=False, `matscipy.neigbbours.neighbour_list('ij', bulk, alat)`. origin: tuple Optionally pass in coordinate origin (x0, y0) + return_error: bool + Optionally return the `scipy.optimize.minimize` error of the fit. + Requires an additional cost_function() call. Returns ------- @@ -1050,12 +1053,19 @@ def fit_core_position(dislo_image, bulk, elastic_param, hard_core=False, dislo_image, bulk, core_radius, elastic_param, hard_core, False, False, bulk_neighbours, origin), method='Powell', options={'xtol': 1e-2, 'ftol': 1e-2}) - return res.x + + error = cost_function(res.x, dislo_image, bulk, core_radius, elastic_param, hard_core, + False, False, bulk_neighbours, origin) + + if return_error: + return res.x, error + else: + return res.x def fit_core_position_images(images, bulk, elastic_param, bulk_neighbours=None, - origin=(0, 0)): + origin=(0, 0), return_errors=False): """ Call fit_core_position() for a list of Atoms objects, e.g. NEB images @@ -1071,24 +1081,41 @@ def fit_core_position_images(images, bulk, elastic_param, as for `fit_core_position()`. origin: tuple Optionally pass in coordinate origin (x0, y0) + return_errors: bool + Optionally return the `scipy.optimize.minimize` list of errors + of the fit. Returns ------- core_positions: array of shape `(len(images), 2)` """ core_positions = [] + + if return_errors: + core_pos_errors = [] + core_position = images[0].info.get('core_position', origin) for dislo in images: dislo_tmp = dislo.copy() - core_position = fit_core_position(dislo_tmp, bulk, elastic_param, + result = fit_core_position(dislo_tmp, bulk, elastic_param, current_pos=dislo.info.get( 'core_position', core_position), bulk_neighbours=bulk_neighbours, - origin=origin) + origin=origin, return_error=return_errors) + + if return_errors: + core_position, error = result + core_pos_errors.append(error) + else: + core_position = result + dislo.info['core_position'] = core_position core_positions.append(core_position) - return np.array(core_positions) + if return_errors: + return np.array(core_positions), np.array(cor_pos_errors) + else: + return np.array(core_positions) def screw_cyl_tetrahedral(alat, C11, C12, C44,