Deploying multiple machine learning models on a single server

Getting to know the problem

In commercial development, many machine learning use cases imply a multi-tenant architecture and require training a separate model for each client and/or user.





As an example, consider forecasting purchases and demand for certain products using machine learning. If you run a chain of retail stores, you can use customer purchase history data and total demand for those products to predict costs and purchase volumes for each store individually.





In such cases, the usual way to deploy a model is to write a Flask service and package it in a Docker container. Examples of single-model machine learning servers are plentiful, but when it comes to deploying multiple models, the developer has few ready-made options.





In multi-tenant applications, the number of tenants is not known in advance and can be practically unlimited: at one moment you may have a single client, and at another you may be serving thousands of users, each with a separate model. This is where the limitations of the standard deployment approach begin to emerge:





  • If we deploy a Docker container for each client, we end up with a very large and expensive application that is difficult to manage.





  • A single container whose image bundles every model does not work either: the server may need to serve thousands of models, and new models are added at runtime.





Solution

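One way out is to keep the models out of the Docker image and load them at runtime instead. For example, models can be trained by Airflow and uploaded to S3, and the ML server fetches them from there.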





The ML server itself does one thing: request -> prediction.





To build it, we introduce the following abstractions:





  • Model: an abstract model that returns predictions; implementations can be SklearnModel, TensorFlowModel, MyCustomModel, and so on.





  • ModelInfoRepository: an abstract repository that provides the userid -> modelid mapping. It can be implemented, for example, as SQAlchemyModelInfoRepository.





  • ModelRepository: an abstract repository that returns a model by its ID. It can be implemented as FileSystemRepository, S3Repository, and so on.





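from abc import ABC, abstractmethod

import numpy as np
import pandas as pd


class Model(ABC):
    """A trained model that can produce predictions."""

    @abstractmethod
    def predict(self, data: pd.DataFrame) -> np.ndarray:
        raise NotImplementedError


class ModelInfoRepository(ABC):
    """Resolves which model belongs to which user."""

    @abstractmethod
    def get_model_id_by_user_id(self, user_id: str) -> str:
        raise NotImplementedError


class ModelRepository(ABC):
    """Loads a model from storage by its ID."""

    @abstractmethod
    def get_model(self, model_id: str) -> Model:
        raise NotImplementedError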
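To make the flow user ID -> model ID -> model -> prediction concrete, here is a minimal sketch with in-memory stand-ins for the two repositories. All names in it (ConstantModel, InMemoryModelInfoRepository, InMemoryModelRepository, predict_for_user) are hypothetical and exist only for illustration; they use plain lists instead of pandas so the snippet runs on the standard library alone.

```python
from typing import Dict, List, Sequence


class ConstantModel:
    """Hypothetical toy model: predicts the same value for every input row."""

    def __init__(self, value: float):
        self.value = value

    def predict(self, data: Sequence) -> List[float]:
        return [self.value for _ in data]


class InMemoryModelInfoRepository:
    """Dict-backed stand-in for the user ID -> model ID lookup."""

    def __init__(self, user_to_model: Dict[str, str]):
        self._user_to_model = dict(user_to_model)

    def get_model_id_by_user_id(self, user_id: str) -> str:
        return self._user_to_model[user_id]


class InMemoryModelRepository:
    """Dict-backed stand-in for model storage."""

    def __init__(self, models: Dict[str, ConstantModel]):
        self._models = dict(models)

    def get_model(self, model_id: str) -> ConstantModel:
        return self._models[model_id]


def predict_for_user(user_id, data, info_repo, model_repo):
    # Step 1: find out which model serves this user
    model_id = info_repo.get_model_id_by_user_id(user_id)
    # Step 2: load that model from storage
    model = model_repo.get_model(model_id)
    # Step 3: run the prediction
    return model.predict(data)


info_repo = InMemoryModelInfoRepository({"alice": "m1", "bob": "m2"})
model_repo = InMemoryModelRepository(
    {"m1": ConstantModel(1.0), "m2": ConstantModel(2.0)}
)
alice_predictions = predict_for_user("alice", [[0], [1]], info_repo, model_repo)
```

Stand-ins like these are also handy in unit tests, since the serving logic can be exercised without a database or object storage.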
      
      



For example, suppose we use sklearn models stored in Amazon S3, and the userid -> modelid mapping is kept in a database.





from sqlalchemy.orm import Session


class SklearnModel(Model):
    def __init__(self, model):
        # Any fitted sklearn estimator that exposes predict()
        self.model = model

    def predict(self, data: pd.DataFrame) -> np.ndarray:
        return self.model.predict(data)


class SQAlchemyModelInfoRepository(ModelInfoRepository):
    def __init__(self, sqalchemy_session: Session):
        self.session = sqalchemy_session

    def get_model_id_by_user_id(self, user_id: str) -> str:
        # implementation goes here, query a table in any database
        raise NotImplementedError


class S3ModelRepository(ModelRepository):
    def __init__(self, s3_client):
        self.s3_client = s3_client

    def get_model(self, model_id: str) -> Model:
        # load and deserialize pickle from S3, implementation goes here
        raise NotImplementedError
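The "load and deserialize pickle from S3" step could look roughly like the sketch below. It assumes a boto3-style client whose get_object(Bucket=..., Key=...) call returns a dict with a readable "Body". The bucket name, the models/<model_id>.pkl key layout, and the helper names are assumptions made for illustration, and FakeS3Client is a stub that lets the snippet run without AWS access.

```python
import io
import pickle


def model_key(model_id: str) -> str:
    # Hypothetical key layout; adapt to however your models are stored
    return f"models/{model_id}.pkl"


def load_pickled_model(s3_client, bucket: str, key: str):
    """Download one object from S3 and unpickle it.

    Only unpickle objects you wrote yourself: pickle can execute
    arbitrary code while loading.
    """
    response = s3_client.get_object(Bucket=bucket, Key=key)
    return pickle.loads(response["Body"].read())


class FakeS3Client:
    """Stub that mimics the one boto3 call used above, for local runs."""

    def __init__(self, objects):
        self._objects = objects

    def get_object(self, Bucket, Key):
        return {"Body": io.BytesIO(self._objects[(Bucket, Key)])}


fake_client = FakeS3Client(
    {("my-models", "models/m1.pkl"): pickle.dumps({"weights": [1, 2, 3]})}
)
loaded = load_pickled_model(fake_client, "my-models", model_key("m1"))
```

In a real repository, the loaded object would be wrapped in the matching Model subclass before being returned.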
      
      



The server itself then looks like this:





import pandas as pd
from flask import Flask, jsonify, request


def make_app(model_info_repository: ModelInfoRepository,
             model_repository: ModelRepository) -> Flask:
    app = Flask("multi-model-server")

    @app.route("/predict/<user_id>", methods=["POST"])
    def predict(user_id):
        # Resolve which model serves this user, then load it
        model_id = model_info_repository.get_model_id_by_user_id(user_id)
        model = model_repository.get_model(model_id)

        # The JSON request body carries the feature rows
        data = pd.DataFrame(request.get_json())
        predictions = model.predict(data)

        return jsonify(predictions.tolist())

    return app
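From the client side, a call to this endpoint might be built as in the sketch below. It assumes the route accepts a POST request whose JSON body is a list of records (one dict per row, a shape pd.DataFrame accepts). The base URL, the user ID, the feature names, and the build_predict_request helper are all hypothetical.

```python
import json
import urllib.request
from urllib.parse import quote


def build_predict_request(base_url: str, user_id: str, rows: list) -> urllib.request.Request:
    """Build (but do not send) a POST request to /predict/<user_id>."""
    url = f"{base_url.rstrip('/')}/predict/{quote(user_id, safe='')}"
    body = json.dumps(rows).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_predict_request(
    "http://localhost:5000",
    "store-42",
    [{"price": 9.99, "day_of_week": 3}],
)

# To actually send it against a running server:
# with urllib.request.urlopen(req) as resp:
#     predictions = json.load(resp)
```

Because the user ID is part of the URL, the same server process can answer for any tenant without being redeployed.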
      
      



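Note that, thanks to the abstractions, the Flask code does not depend on a concrete model or storage; you can replace sklearn with TensorFlow, or S3 with another storage backend, and the Flask code stays the same.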





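Loading a model from storage on every request is slow, so it makes sense to cache loaded models. Not every model will fit in memory at once, so the cache has to be bounded. One option is cachetools: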





from cachetools import Cache


class CachedModelRepository(ModelRepository):
    """Wraps another repository and keeps recently loaded models in memory."""

    def __init__(self, model_repository: ModelRepository, cache: Cache):
        self.model_repository = model_repository
        self.cache = cache

    def get_model(self, model_id: str) -> Model:
        # Hit the underlying repository only on a cache miss
        if model_id in self.cache:
            return self.cache[model_id]
        model = self.model_repository.get_model(model_id)
        self.cache[model_id] = model
        return model
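To show what least-recently-used eviction means for a model cache, here is a small standard-library sketch built on OrderedDict. It is not how cachetools is implemented; TinyLRU and slow_loader are hypothetical names used only to illustrate the behaviour.

```python
from collections import OrderedDict


class TinyLRU:
    """Illustrative bounded cache that evicts the least recently used key."""

    def __init__(self, maxsize: int):
        self.maxsize = maxsize
        self._items = OrderedDict()

    def get_or_load(self, key, loader):
        if key in self._items:
            # Cache hit: mark the key as most recently used
            self._items.move_to_end(key)
            return self._items[key]
        # Cache miss: load, store, and evict the oldest entry if over capacity
        value = loader(key)
        self._items[key] = value
        if len(self._items) > self.maxsize:
            self._items.popitem(last=False)
        return value

    def keys(self):
        return list(self._items)


loads = []


def slow_loader(model_id):
    # Stands in for an expensive download from model storage
    loads.append(model_id)
    return f"model:{model_id}"


cache = TinyLRU(maxsize=2)
for model_id in ["a", "b", "a", "c", "b"]:
    cache.get_or_load(model_id, slow_loader)
```

With a capacity of two, the second request for "a" is served from memory, while "b" has to be loaded again after "c" pushed it out. The cache size is therefore a trade-off between memory use and how often models are reloaded.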
      
      



Usage example:





from cachetools import LRUCache

# Keep at most 10 models in memory; the least recently used one is evicted first
model_repository = CachedModelRepository(
    S3ModelRepository(s3_client),
    LRUCache(maxsize=10)
)
      
      



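This approach is deliberately simple, and dedicated MLOps tooling may be a better fit for larger deployments. Still, a simple server is often enough to get started. As rule №4 of Google's Rules of Machine Learning puts it: keep the first model simple and get the infrastructure right.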







