forked from feliciahsieh/AirBnB_clone_v3
-
Notifications
You must be signed in to change notification settings - Fork 64
/
db_storage.py
executable file
·103 lines (91 loc) · 3.16 KB
/
db_storage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!/usr/bin/python3
"""Database storage engine using SQLAlchemy with a mysql+mysqldb database
connection.
"""
import os
from models.base_model import Base
from models.amenity import Amenity
from models.city import City
from models.place import Place
from models.state import State
from models.review import Review
from models.user import User
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session
# Lookup table from a model's class name to its mapped class. DBStorage uses
# it both to translate class names passed as strings and to iterate over
# every model when no specific class is requested.
name2class = dict(
    Amenity=Amenity,
    City=City,
    Place=Place,
    State=State,
    Review=Review,
    User=User,
)
class DBStorage:
    """Database storage engine backed by SQLAlchemy (mysql+mysqldb).

    Connection settings are read from the HBNB_MYSQL_USER, HBNB_MYSQL_PWD,
    HBNB_MYSQL_HOST and HBNB_MYSQL_DB environment variables. A session is
    opened lazily by reload() the first time any method needs one.
    """
    # SQLAlchemy engine created in __init__.
    __engine = None
    # scoped_session registry created by reload(); None until then.
    __session = None

    def __init__(self):
        """Create the engine; drop all mapped tables when HBNB_ENV is 'test'."""
        user = os.getenv('HBNB_MYSQL_USER')
        passwd = os.getenv('HBNB_MYSQL_PWD')
        host = os.getenv('HBNB_MYSQL_HOST')
        database = os.getenv('HBNB_MYSQL_DB')
        self.__engine = create_engine('mysql+mysqldb://{}:{}@{}/{}'
                                      .format(user, passwd, host, database))
        if os.getenv('HBNB_ENV') == 'test':
            Base.metadata.drop_all(self.__engine)

    def _ensure_session(self):
        """Open a session via reload() if none has been opened yet."""
        if self.__session is None:
            self.reload()

    @staticmethod
    def _resolve_class(cls):
        """Return the model class for cls, or None if it cannot be resolved.

        A string is looked up in name2class; a class object is returned
        unchanged; anything else (including None) yields None.
        """
        if isinstance(cls, str):
            return name2class.get(cls)
        if isinstance(cls, type):
            return cls
        return None

    def all(self, cls=None):
        """Return stored objects as a dict keyed by '<ClassName>.<id>'.

        Args:
            cls: a model class, a model class name, or None for all models.

        Returns:
            dict mapping '<ClassName>.<id>' to the object. An unknown class
            name (or an unusable cls value) yields an empty dict rather
            than every object in the database.
        """
        if cls is None:
            models = list(name2class.values())
        else:
            model = self._resolve_class(cls)
            if model is None:
                return {}
            models = [model]
        self._ensure_session()
        objects = {}
        for model_cls in models:
            for obj in self.__session.query(model_cls):
                objects[obj.__class__.__name__ + '.' + obj.id] = obj
        return objects

    def reload(self):
        """Create all mapped tables and open a new scoped session."""
        # expire_on_commit=False keeps loaded attribute values available on
        # objects after commit() instead of expiring them.
        session_factory = sessionmaker(bind=self.__engine,
                                       expire_on_commit=False)
        Base.metadata.create_all(self.__engine)
        self.__session = scoped_session(session_factory)

    def new(self, obj):
        """Add obj to the current session; it is persisted on save()."""
        self._ensure_session()
        self.__session.add(obj)

    def save(self):
        """Commit all pending changes in the current session."""
        self._ensure_session()
        self.__session.commit()

    def delete(self, obj=None):
        """Mark obj for deletion in the current session (no-op for None)."""
        self._ensure_session()
        if obj is not None:
            self.__session.delete(obj)

    def close(self):
        """Remove the scoped session's current session, if one was opened."""
        if self.__session is not None:
            self.__session.remove()

    def get(self, cls, id):
        """Retrieve one object by class and id.

        Args:
            cls: a model class or a model class name.
            id: the object's id; must be a string.

        Returns:
            The matching object, or None when the id is not a string, the
            class cannot be resolved, or no such object exists.
        """
        if not isinstance(id, str):
            return None
        model = self._resolve_class(cls)
        if model is None:
            return None
        self._ensure_session()
        return self.__session.query(model).filter(model.id == id).first()

    def count(self, cls=None):
        """Count stored objects.

        Args:
            cls: a model class, a model class name, or None for all models.

        Returns:
            The number of matching rows; 0 for an unknown class name.
        """
        if cls is None:
            models = list(name2class.values())
        else:
            model = self._resolve_class(cls)
            if model is None:
                return 0
            models = [model]
        self._ensure_session()
        return sum(self.__session.query(model_cls).count()
                   for model_cls in models)