summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-02-20 15:05:36 +0000
committerEdoardo Pasca <edo.paskino@gmail.com>2019-02-20 15:05:36 +0000
commit7e0ed0c5fef0382d6b6903d8132fd06a2c4d2967 (patch)
treeb2c68db3b03bbf1567a890d01495fc8fa91e813d /Wrappers/Python
parent5317bf21b45433313907c8f4d6331230c2c8349f (diff)
downloadframework-7e0ed0c5fef0382d6b6903d8132fd06a2c4d2967.tar.gz
framework-7e0ed0c5fef0382d6b6903d8132fd06a2c4d2967.tar.bz2
framework-7e0ed0c5fef0382d6b6903d8132fd06a2c4d2967.tar.xz
framework-7e0ed0c5fef0382d6b6903d8132fd06a2c4d2967.zip
add run method
Diffstat (limited to 'Wrappers/Python')
-rw-r--r--Wrappers/Python/ccpi/optimisation/Algorithms.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/Algorithms.py b/Wrappers/Python/ccpi/optimisation/Algorithms.py
index 9115e6e..0a5cac6 100644
--- a/Wrappers/Python/ccpi/optimisation/Algorithms.py
+++ b/Wrappers/Python/ccpi/optimisation/Algorithms.py
@@ -84,6 +84,12 @@ class Algorithm(object):
def max_iteration(self, value):
assert isinstance(value, int)
self.__max_iteration = value
+ def run(self, iterations, callback=None):
+ '''run n iterations and update the user with the callback if specified'''
+ self.max_iteration += iterations
+ for _ in self:
+ if callback is not None:
+ callback(self.iteration, self.get_current_loss())
class GradientDescent(Algorithm):
'''Implementation of a simple Gradient Descent algorithm