// Last modified: Dec 14, 2019, 11:10 AM.
1// Copyright 2017 The casbin Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package gormadapter
16
import (
	"errors"
	"runtime"
	"strings"

	"github.com/casbin/casbin/v2/model"
	"github.com/casbin/casbin/v2/persist"
	"github.com/jinzhu/gorm"
	"github.com/lib/pq"
)
26
var (
	// tablePrefix records the prefix passed to the most recent
	// NewAdapterByDBUsePrefix call. It is shared package-wide; each Adapter
	// actually uses its own tablePrefix field.
	// NOTE(review): not read anywhere in this file — presumably used
	// elsewhere in the package; verify before removing.
	tablePrefix string
)
28
// CasbinRule is the GORM model for one stored rule row.
//
// PType holds the rule's ptype (the model assertion key it belongs to, e.g.
// "p" or "g"), and V0..V5 hold the rule's values in order. Positions beyond
// the rule's length stay as empty strings. TablePrefix is excluded from the
// schema (gorm:"-") and only feeds TableName.
type CasbinRule struct {
	TablePrefix            string `gorm:"-"`
	PType                  string `gorm:"size:100"`
	V0, V1, V2, V3, V4, V5 string `gorm:"size:100"`
}
39
// Filter selects which rules LoadFilteredPolicy loads. Each non-empty slice
// restricts the matching column (PType -> p_type, Vn -> vn) to the listed
// values via an SQL "in" condition. A nil or empty slice leaves that column
// unconstrained.
type Filter struct {
	PType, V0, V1, V2, V3, V4, V5 []string
}
49
50func (c *CasbinRule) TableName() string {
51 return c.TablePrefix + "casbin_rule" //as Gorm keeps table names are plural, and we love consistency
52}
53
54// Adapter represents the Gorm adapter for policy storage.
55type Adapter struct {
56 tablePrefix string
57 driverName string
58 dataSourceName string
59 dbSpecified bool
60 db *gorm.DB
61 isFiltered bool
62}
63
64// finalizer is the destructor for Adapter.
65func finalizer(a *Adapter) {
66 err := a.db.Close()
67 if err != nil {
68 panic(err)
69 }
70}
71
72// NewAdapter is the constructor for Adapter.
73// dbSpecified is an optional bool parameter. The default value is false.
74// It's up to whether you have specified an existing DB in dataSourceName.
75// If dbSpecified == true, you need to make sure the DB in dataSourceName exists.
76// If dbSpecified == false, the adapter will automatically create a DB named "casbin".
77func NewAdapter(driverName string, dataSourceName string, dbSpecified ...bool) (*Adapter, error) {
78 a := &Adapter{}
79 a.driverName = driverName
80 a.dataSourceName = dataSourceName
81
82 if len(dbSpecified) == 0 {
83 a.dbSpecified = false
84 } else if len(dbSpecified) == 1 {
85 a.dbSpecified = dbSpecified[0]
86 } else {
87 return nil, errors.New("invalid parameter: dbSpecified")
88 }
89
90 // Open the DB, create it if not existed.
91 err := a.open()
92 if err != nil {
93 return nil, err
94 }
95
96 // Call the destructor when the object is released.
97 runtime.SetFinalizer(a, finalizer)
98
99 return a, nil
100}
101
102// NewAdapterByDB obtained through an existing Gorm instance get a adapter, specify the table prefix
103// Example: gormadapter.NewAdapterByDBUsePrefix(&db, "cms_") Automatically generate table name like this "cms_casbin_rule"
104func NewAdapterByDBUsePrefix(db *gorm.DB, prefix string) (*Adapter, error) {
105 a := &Adapter{
106 tablePrefix: prefix,
107 db: db,
108 }
109
110 tablePrefix = prefix
111
112 err := a.createTable()
113 if err != nil {
114 return nil, err
115 }
116
117 return a, nil
118}
119
120func NewAdapterByDB(db *gorm.DB) (*Adapter, error) {
121 a := &Adapter{
122 db: db,
123 }
124
125 err := a.createTable()
126 if err != nil {
127 return nil, err
128 }
129
130 return a, nil
131}
132
133func (a *Adapter) createDatabase() error {
134 var err error
135 var db *gorm.DB
136 if a.driverName == "postgres" {
137 db, err = gorm.Open(a.driverName, a.dataSourceName+" dbname=postgres")
138 } else {
139 db, err = gorm.Open(a.driverName, a.dataSourceName)
140 }
141 if err != nil {
142 return err
143 }
144
145 if a.driverName == "postgres" {
146 if err = db.Exec("CREATE DATABASE casbin").Error; err != nil {
147 // 42P04 is duplicate_database
148 if err.(*pq.Error).Code == "42P04" {
149 db.Close()
150 return nil
151 }
152 }
153 } else if a.driverName != "sqlite3" {
154 err = db.Exec("CREATE DATABASE IF NOT EXISTS casbin").Error
155 }
156 if err != nil {
157 db.Close()
158 return err
159 }
160
161 return db.Close()
162}
163
164func (a *Adapter) open() error {
165 var err error
166 var db *gorm.DB
167
168 if a.dbSpecified {
169 db, err = gorm.Open(a.driverName, a.dataSourceName)
170 if err != nil {
171 return err
172 }
173 } else {
174 if err = a.createDatabase(); err != nil {
175 return err
176 }
177
178 if a.driverName == "postgres" {
179 db, err = gorm.Open(a.driverName, a.dataSourceName+" dbname=casbin")
180 } else if a.driverName == "sqlite3" {
181 db, err = gorm.Open(a.driverName, a.dataSourceName)
182 } else {
183 db, err = gorm.Open(a.driverName, a.dataSourceName+"casbin")
184 }
185 if err != nil {
186 return err
187 }
188 }
189
190 a.db = db
191
192 return a.createTable()
193}
194
195func (a *Adapter) close() error {
196 err := a.db.Close()
197 if err != nil {
198 return err
199 }
200
201 a.db = nil
202 return nil
203}
204
205// getTableInstance return the dynamic table name
206func (a *Adapter) getTableInstance() *CasbinRule {
207 return &CasbinRule{TablePrefix: a.tablePrefix}
208}
209
210func (a *Adapter) createTable() error {
211 if a.db.HasTable(a.getTableInstance()) {
212 return nil
213 }
214
215 return a.db.CreateTable(a.getTableInstance()).Error
216}
217
218func (a *Adapter) dropTable() error {
219 return a.db.DropTable(a.getTableInstance()).Error
220}
221
222func loadPolicyLine(line CasbinRule, model model.Model) {
223 lineText := line.PType
224 if line.V0 != "" {
225 lineText += ", " + line.V0
226 }
227 if line.V1 != "" {
228 lineText += ", " + line.V1
229 }
230 if line.V2 != "" {
231 lineText += ", " + line.V2
232 }
233 if line.V3 != "" {
234 lineText += ", " + line.V3
235 }
236 if line.V4 != "" {
237 lineText += ", " + line.V4
238 }
239 if line.V5 != "" {
240 lineText += ", " + line.V5
241 }
242
243 persist.LoadPolicyLine(lineText, model)
244}
245
246// LoadPolicy loads policy from database.
247func (a *Adapter) LoadPolicy(model model.Model) error {
248 var lines []CasbinRule
249 if err := a.db.Table(a.tablePrefix + "casbin_rule").Find(&lines).Error; err != nil {
250 return err
251 }
252
253 for _, line := range lines {
254 loadPolicyLine(line, model)
255 }
256
257 return nil
258}
259
260// LoadFilteredPolicy loads only policy rules that match the filter.
261func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) error {
262 var lines []CasbinRule
263
264 filterValue, ok := filter.(Filter)
265 if !ok {
266 return errors.New("invalid filter type")
267 }
268
269 if err := a.db.Scopes(a.filterQuery(a.db, filterValue)).Find(&lines).Error; err != nil {
270 return err
271 }
272
273 for _, line := range lines {
274 loadPolicyLine(line, model)
275 }
276 a.isFiltered = true
277
278 return nil
279}
280
281// IsFiltered returns true if the loaded policy has been filtered.
282func (a *Adapter) IsFiltered() bool {
283 return a.isFiltered
284}
285
286// filterQuery builds the gorm query to match the rule filter to use within a scope.
287func (a *Adapter) filterQuery(db *gorm.DB, filter Filter) func(db *gorm.DB) *gorm.DB {
288 return func(db *gorm.DB) *gorm.DB {
289 if len(filter.PType) > 0 {
290 db = db.Where("p_type in (?)", filter.PType)
291 }
292 if len(filter.V0) > 0 {
293 db = db.Where("v0 in (?)", filter.V0)
294 }
295 if len(filter.V1) > 0 {
296 db = db.Where("v1 in (?)", filter.V1)
297 }
298 if len(filter.V2) > 0 {
299 db = db.Where("v2 in (?)", filter.V2)
300 }
301 if len(filter.V3) > 0 {
302 db = db.Where("v3 in (?)", filter.V3)
303 }
304 if len(filter.V4) > 0 {
305 db = db.Where("v4 in (?)", filter.V4)
306 }
307 if len(filter.V5) > 0 {
308 db = db.Where("v5 in (?)", filter.V5)
309 }
310 return db
311 }
312}
313
314func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule {
315 line := a.getTableInstance()
316
317 line.PType = ptype
318 if len(rule) > 0 {
319 line.V0 = rule[0]
320 }
321 if len(rule) > 1 {
322 line.V1 = rule[1]
323 }
324 if len(rule) > 2 {
325 line.V2 = rule[2]
326 }
327 if len(rule) > 3 {
328 line.V3 = rule[3]
329 }
330 if len(rule) > 4 {
331 line.V4 = rule[4]
332 }
333 if len(rule) > 5 {
334 line.V5 = rule[5]
335 }
336
337 return *line
338}
339
340// SavePolicy saves policy to database.
341func (a *Adapter) SavePolicy(model model.Model) error {
342 err := a.dropTable()
343 if err != nil {
344 return err
345 }
346 err = a.createTable()
347 if err != nil {
348 return err
349 }
350
351 for ptype, ast := range model["p"] {
352 for _, rule := range ast.Policy {
353 line := a.savePolicyLine(ptype, rule)
354 err := a.db.Create(&line).Error
355 if err != nil {
356 return err
357 }
358 }
359 }
360
361 for ptype, ast := range model["g"] {
362 for _, rule := range ast.Policy {
363 line := a.savePolicyLine(ptype, rule)
364 err := a.db.Create(&line).Error
365 if err != nil {
366 return err
367 }
368 }
369 }
370
371 return nil
372}
373
374// AddPolicy adds a policy rule to the storage.
375func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
376 line := a.savePolicyLine(ptype, rule)
377 err := a.db.Create(&line).Error
378 return err
379}
380
381// RemovePolicy removes a policy rule from the storage.
382func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
383 line := a.savePolicyLine(ptype, rule)
384 err := a.rawDelete(a.db, line) //can't use db.Delete as we're not using primary key http://jinzhu.me/gorm/crud.html#delete
385 return err
386}
387
388// RemoveFilteredPolicy removes policy rules that match the filter from the storage.
389func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
390 line := a.getTableInstance()
391
392 line.PType = ptype
393 if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
394 line.V0 = fieldValues[0-fieldIndex]
395 }
396 if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
397 line.V1 = fieldValues[1-fieldIndex]
398 }
399 if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
400 line.V2 = fieldValues[2-fieldIndex]
401 }
402 if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
403 line.V3 = fieldValues[3-fieldIndex]
404 }
405 if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
406 line.V4 = fieldValues[4-fieldIndex]
407 }
408 if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
409 line.V5 = fieldValues[5-fieldIndex]
410 }
411 err := a.rawDelete(a.db, *line)
412 return err
413}
414
415func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error {
416 queryArgs := []interface{}{line.PType}
417
418 queryStr := "p_type = ?"
419 if line.V0 != "" {
420 queryStr += " and v0 = ?"
421 queryArgs = append(queryArgs, line.V0)
422 }
423 if line.V1 != "" {
424 queryStr += " and v1 = ?"
425 queryArgs = append(queryArgs, line.V1)
426 }
427 if line.V2 != "" {
428 queryStr += " and v2 = ?"
429 queryArgs = append(queryArgs, line.V2)
430 }
431 if line.V3 != "" {
432 queryStr += " and v3 = ?"
433 queryArgs = append(queryArgs, line.V3)
434 }
435 if line.V4 != "" {
436 queryStr += " and v4 = ?"
437 queryArgs = append(queryArgs, line.V4)
438 }
439 if line.V5 != "" {
440 queryStr += " and v5 = ?"
441 queryArgs = append(queryArgs, line.V5)
442 }
443 args := append([]interface{}{queryStr}, queryArgs...)
444 err := db.Delete(a.getTableInstance(), args...).Error
445 return err
446}