diff --git a/.gitignore b/.gitignore index c1ca4c20..a5933378 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,7 @@ /doc/ /dev/ -/lib/ /bin/ /.shards/ /.vscode/ /invidious /sentry -shard.lock diff --git a/.travis.yml b/.travis.yml index f5918bb1..351bca41 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,9 +5,6 @@ jobs: - stage: build language: crystal crystal: latest - before_install: - - shards update - - shards install install: - crystal build --error-on-warnings src/invidious.cr script: diff --git a/README.md b/README.md index be7c5580..9b5dffe5 100644 --- a/README.md +++ b/README.md @@ -124,7 +124,6 @@ $ exit ```bash $ sudo -i -u invidious $ cd invidious -$ shards update && shards install $ crystal build src/invidious.cr --release # test compiled binary $ ./invidious # stop with ctrl c @@ -161,7 +160,6 @@ $ psql invidious kemal < config/sql/nonces.sql $ psql invidious kemal < config/sql/annotations.sql # Setup Invidious -$ shards update && shards install $ crystal build src/invidious.cr --release ``` diff --git a/docker/Dockerfile b/docker/Dockerfile index 45fade57..0576c1c7 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -3,7 +3,8 @@ RUN apk add -u crystal shards libc-dev \ yaml-dev libxml2-dev sqlite-dev sqlite-static zlib-dev openssl-dev WORKDIR /invidious COPY ./shard.yml ./shard.yml -RUN shards update && shards install +COPY ./shard.lock ./shard.lock +COPY ./lib/ ./lib/ COPY ./src/ ./src/ # TODO: .git folder is required for building – this is destructive. # See definition of CURRENT_BRANCH, CURRENT_COMMIT and CURRENT_VERSION. diff --git a/lib/db/.gitignore b/lib/db/.gitignore new file mode 100644 index 00000000..e978b32e --- /dev/null +++ b/lib/db/.gitignore @@ -0,0 +1,10 @@ +/docs/ +/lib/ +/bin/ +/.shards/ +*.dwarf + +# Libraries don't need dependency lock +# Dependencies will be locked in application that uses them +/shard.lock + diff --git a/lib/db/.travis.yml b/lib/db/.travis.yml new file mode 100644 index 00000000..31f20e13 --- /dev/null +++ b/lib/db/.travis.yml @@ -0,0 +1,7 @@ +language: crystal +crystal: + - latest + - nightly +script: + - crystal spec + - crystal tool format --check diff --git a/lib/db/CHANGELOG.md b/lib/db/CHANGELOG.md new file mode 100644 index 00000000..1a1cf15d --- /dev/null +++ b/lib/db/CHANGELOG.md @@ -0,0 +1,93 @@ +## v0.6.0 (2019-08-02) + +* Fix compatibility issues for Crystal 0.30.0. ([#108](https://github.com/crystal-lang/crystal-db/pull/108), thanks @bcardiff) +* Fix `BeginTransaction#transaction` rollback due to protocol error. ([#101](https://github.com/crystal-lang/crystal-db/pull/101), thanks @straight-shoota) +* CI includes Crystal nightly. ([#106](https://github.com/crystal-lang/crystal-db/pull/106), thanks @bcardiff) +* Add the Cassandra driver. ([#94](https://github.com/crystal-lang/crystal-db/pull/94), thanks @kaukas) +* Several docs improvements. ([#99](https://github.com/crystal-lang/crystal-db/pull/99), [#96](https://github.com/crystal-lang/crystal-db/pull/96), [#107](https://github.com/crystal-lang/crystal-db/pull/107), thanks @nickelghost, @greenbigfrog, @MatthiasWinkelmann) + +## v0.5.1 (2018-11-07) + +* Fix `QueryMethods#query_one?` handling no rows. ([#86](https://github.com/crystal-lang/crystal-db/pull/86), thanks @robdavid) +* Documentation improvements. ([#87](https://github.com/crystal-lang/crystal-db/pull/87), [#82](https://github.com/crystal-lang/crystal-db/pull/82), [#76](https://github.com/crystal-lang/crystal-db/pull/76), thanks @wontruefree, @Heaven31415, @vtambourine) + +## v0.5.0 (2017-12-29) + +* Fix compatibility issues for crystal 0.24.0. No changes in the api. + +## v0.4.4 (2017-12-29) + +* Allow query results to be read as named tuples directly (see [#56](https://github.com/crystal-lang/crystal-db/pull/56), thanks @Nephos) +* Fix sqlite samples in documentation (see [#71](https://github.com/crystal-lang/crystal-db/pull/71), thanks @hinrik) + +## v0.4.3 (2017-11-07) + +* Fix connections were not released when building invalid statements. (see [#65](https://github.com/crystal-lang/crystal-db/pull/65), thanks @crisward) +* Fix some exceptions were not deriving from `DB::Error`. (see [#70](https://github.com/crystal-lang/crystal-db/pull/70), thanks @exts) + +## v0.4.2 (2017-04-21) + +* Fix compatibility issues for crystal 0.22.0 + +## v0.4.1 (2017-04-10) + +* Add spec helper for drivers. [#48](https://github.com/crystal-lang/crystal-db/pull/48) +* Add `#query_each`. [#18](https://github.com/crystal-lang/crystal-db/issues/18) +* Fix `#read(T.class)` to deal better with unhandled types. + +## v0.4.0 (2017-03-20) + +* Add `DB.connect` to create non pooled connections +* Add `Database#checkout` to allow explicit checkout/release connection (see #38) +* Fix `Mapping.from_rs` closes the result_set +* Fix `Mapping` works with nilable types (see #40, thanks @RX14) + +## v0.3.3 (2016-12-24) + +* Fix compatibility issues for crystal 0.20.3 + +## v0.3.2 (2016-12-16) + +* Allow connection pool retry logic in `#scalar` queries. + +## v0.3.1 (2016-12-15) + +* Add ConnectionRefused exception to flag issues when opening new connections. + +## v0.3.0 (2016-12-14) + +* Add support for non prepared statements. [#25](https://github.com/crystal-lang/crystal-db/pull/25) + +* Add support for transactions & nested transactions. [#27](https://github.com/crystal-lang/crystal-db/pull/27) + +* Add `Bool` and `Time` to `DB::Any`. + +## v0.2.2 (2016-12-06) + +This release requires crystal 0.20.1 + +* Changed default connection pool size limit is now 0 (unlimited). + +* Fixed allow new connections right away if pool can be increased. + +## ~~v0.2.1 (2016-12-06)~~ [YANKED] + +## v0.2.0 (2016-10-20) + +* Fixed release DB connection if an exception occurs during execution of a query (thanks @ggiraldez) + +## ~~v0.1.1 (2016-09-28)~~ [YANKED] + +This release requires crystal 0.19.2 + +Note: v0.1.1 is yanked since is incompatible with v0.1.0 [more](https://github.com/crystal-lang/crystal-mysql/issues/10). + +* Added connection pool. `DB.open` works with a underlying connection pool. Use `Database#using_connection` to ensure the same connection is been used across multiple statements. [more](https://github.com/crystal-lang/crystal-db/pull/12) + +* Added mappings. JSON/YAML-like mapping macros (thanks @spalladino) [more](https://github.com/crystal-lang/crystal-db/pull/2) + +* Changed require ResultSet implementors to just implement `read`, optionally implementing `read(T.class)`. + +## v0.1.0 (2016-06-24) + +* Initial release diff --git a/lib/db/LICENSE b/lib/db/LICENSE new file mode 100644 index 00000000..ab07ebce --- /dev/null +++ b/lib/db/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 Brian J. Cardiff + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/lib/db/README.md b/lib/db/README.md new file mode 100644 index 00000000..fca3379f --- /dev/null +++ b/lib/db/README.md @@ -0,0 +1,97 @@ +[![Build Status](https://travis-ci.org/crystal-lang/crystal-db.svg?branch=master)](https://travis-ci.org/crystal-lang/crystal-db) + +# crystal-db + +Common db api for crystal. You will need to have a specific driver to access a database. + +* [SQLite](https://github.com/crystal-lang/crystal-sqlite3) +* [MySQL](https://github.com/crystal-lang/crystal-mysql) +* [PostgreSQL](https://github.com/will/crystal-pg) +* [Cassandra](https://github.com/kaukas/crystal-cassandra) + +## Installation + +If you are creating a shard that will work with _any_ driver, then add `crystal-db` as a dependency in `shard.yml`: + +```yaml +dependencies: + db: + github: crystal-lang/crystal-db +``` + +If you are creating an application that will work with _some specific_ driver(s), then add them in `shard.yml`: + +```yaml +dependencies: + sqlite3: + github: crystal-lang/crystal-sqlite3 +``` + +`crystal-db` itself will be a nested dependency if drivers are included. + +Note: Multiple drivers can be included in the same application. + +## Documentation + +* [Latest API](http://crystal-lang.github.io/crystal-db/api/latest/) +* [Crystal book](https://crystal-lang.org/docs/database/) + +## Usage + +This shard only provides an abstract database API. In order to use it, a specific driver for the intended database has to be required as well: + +The following example uses SQLite where `?` indicates the arguments. If PostgreSQL is used `$1`, `$2`, etc. should be used. `crystal-db` does not interpret the statements. + +```crystal +require "db" +require "sqlite3" + +DB.open "sqlite3:./file.db" do |db| + # When using the pg driver, use $1, $2, etc. instead of ? + db.exec "create table contacts (name text, age integer)" + db.exec "insert into contacts values (?, ?)", "John Doe", 30 + + args = [] of DB::Any + args << "Sarah" + args << 33 + db.exec "insert into contacts values (?, ?)", args + + puts "max age:" + puts db.scalar "select max(age) from contacts" # => 33 + + puts "contacts:" + db.query "select name, age from contacts order by age desc" do |rs| + puts "#{rs.column_name(0)} (#{rs.column_name(1)})" + # => name (age) + rs.each do + puts "#{rs.read(String)} (#{rs.read(Int32)})" + # => Sarah (33) + # => John Doe (30) + end + end +end +``` + +## Roadmap + +Issues not yet addressed: + +- [x] Support non prepared statements. [#25](https://github.com/crystal-lang/crystal-db/pull/25) +- [x] Time data type. (implementation details depends on actual drivers) +- [x] Data type extensibility. Allow each driver to extend the data types allowed. +- [x] Transactions & nested transactions. [#27](https://github.com/crystal-lang/crystal-db/pull/27) +- [x] Connection pool. +- [ ] Logging +- [ ] Direct access to `IO` to avoid memory allocation for blobs. + +## Contributing + +1. Fork it ( https://github.com/crystal-lang/crystal-db/fork ) +2. Create your feature branch (git checkout -b my-new-feature) +3. Commit your changes (git commit -am 'Add some feature') +4. Push to the branch (git push origin my-new-feature) +5. Create a new Pull Request + +## Contributors + +- [bcardiff](https://github.com/bcardiff) Brian J. Cardiff - creator, maintainer diff --git a/lib/db/shard.yml b/lib/db/shard.yml new file mode 100644 index 00000000..7a4e26a5 --- /dev/null +++ b/lib/db/shard.yml @@ -0,0 +1,9 @@ +name: db +version: 0.6.0 + +authors: + - Brian J. Cardiff + +crystal: 0.24.0 + +license: MIT diff --git a/lib/db/spec/custom_drivers_types_spec.cr b/lib/db/spec/custom_drivers_types_spec.cr new file mode 100644 index 00000000..21d16685 --- /dev/null +++ b/lib/db/spec/custom_drivers_types_spec.cr @@ -0,0 +1,293 @@ +require "./spec_helper" + +module GenericResultSet + @index = 0 + + def move_next : Bool + @index = 0 + true + end + + def column_count : Int32 + @row.size + end + + def column_name(index : Int32) : String + index.to_s + end + + def read + @index += 1 + @row[@index - 1] + end +end + +class FooValue + def initialize(@value : Int32) + end + + def value + @value + end +end + +class FooDriver < DB::Driver + alias Any = DB::Any | FooValue + @@row = [] of Any + + def self.fake_row=(row : Array(Any)) + @@row = row + end + + def self.fake_row + @@row + end + + def build_connection(context : DB::ConnectionContext) : DB::Connection + FooConnection.new(context) + end + + class FooConnection < DB::Connection + def build_prepared_statement(query) : DB::Statement + FooStatement.new(self) + end + + def build_unprepared_statement(query) : DB::Statement + raise "not implemented" + end + end + + class FooStatement < DB::Statement + protected def perform_query(args : Enumerable) : DB::ResultSet + args.each { |arg| process_arg arg } + FooResultSet.new(self, FooDriver.fake_row) + end + + protected def perform_exec(args : Enumerable) : DB::ExecResult + args.each { |arg| process_arg arg } + DB::ExecResult.new 0i64, 0i64 + end + + private def process_arg(value : FooDriver::Any) + end + + private def process_arg(value) + raise "#{self.class} does not support #{value.class} params" + end + end + + class FooResultSet < DB::ResultSet + include GenericResultSet + + def initialize(statement, @row : Array(FooDriver::Any)) + super(statement) + end + end +end + +DB.register_driver "foo", FooDriver + +class BarValue + getter value + + def initialize(@value : Int32) + end +end + +class BarDriver < DB::Driver + alias Any = DB::Any | BarValue + @@row = [] of Any + + def self.fake_row=(row : Array(Any)) + @@row = row + end + + def self.fake_row + @@row + end + + def build_connection(context : DB::ConnectionContext) : DB::Connection + BarConnection.new(context) + end + + class BarConnection < DB::Connection + def build_prepared_statement(query) : DB::Statement + BarStatement.new(self) + end + + def build_unprepared_statement(query) : DB::Statement + raise "not implemented" + end + end + + class BarStatement < DB::Statement + protected def perform_query(args : Enumerable) : DB::ResultSet + args.each { |arg| process_arg arg } + BarResultSet.new(self, BarDriver.fake_row) + end + + protected def perform_exec(args : Enumerable) : DB::ExecResult + args.each { |arg| process_arg arg } + DB::ExecResult.new 0i64, 0i64 + end + + private def process_arg(value : BarDriver::Any) + end + + private def process_arg(value) + raise "#{self.class} does not support #{value.class} params" + end + end + + class BarResultSet < DB::ResultSet + include GenericResultSet + + def initialize(statement, @row : Array(BarDriver::Any)) + super(statement) + end + end +end + +DB.register_driver "bar", BarDriver + +describe DB do + it "should be able to register multiple drivers" do + DB.open("foo://host").driver.should be_a(FooDriver) + DB.open("bar://host").driver.should be_a(BarDriver) + end + + it "Foo and Bar drivers should return fake_row" do + with_witness do |w| + DB.open("foo://host") do |db| + FooDriver.fake_row = [1, "string", FooValue.new(3)] of FooDriver::Any + db.query "query" do |rs| + w.check + rs.move_next + rs.read(Int32).should eq(1) + rs.read(String).should eq("string") + rs.read(FooValue).value.should eq(3) + end + end + end + + with_witness do |w| + DB.open("bar://host") do |db| + BarDriver.fake_row = [BarValue.new(4), "lorem", 1.0] of BarDriver::Any + db.query "query" do |rs| + w.check + rs.move_next + rs.read(BarValue).value.should eq(4) + rs.read(String).should eq("lorem") + rs.read(Float64).should eq(1.0) + end + end + end + end + + it "drivers should return custom values as scalar" do + DB.open("foo://host") do |db| + FooDriver.fake_row = [FooValue.new(3)] of FooDriver::Any + db.scalar("query").as(FooValue).value.should eq(3) + end + end + + it "Foo and Bar drivers should not implement each other read" do + with_witness do |w| + DB.open("foo://host") do |db| + FooDriver.fake_row = [1] of FooDriver::Any + db.query "query" do |rs| + rs.move_next + expect_raises(Exception, "FooResultSet#read returned a Int32. A BarValue was expected.") do + w.check + rs.read(BarValue) + end + end + end + end + + with_witness do |w| + DB.open("bar://host") do |db| + BarDriver.fake_row = [1] of BarDriver::Any + db.query "query" do |rs| + rs.move_next + expect_raises(Exception, "BarResultSet#read returned a Int32. A FooValue was expected.") do + w.check + rs.read(FooValue) + end + end + end + end + end + + it "allow custom types to be used as arguments for query" do + DB.open("foo://host") do |db| + FooDriver.fake_row = [1, "string"] of FooDriver::Any + db.query "query" { } + db.query "query", 1 { } + db.query "query", 1, "string" { } + db.query("query", Bytes.new(4)) { } + db.query("query", 1, "string", FooValue.new(5)) { } + db.query "query", [1, "string", FooValue.new(5)] { } + + db.query("query").close + db.query("query", 1).close + db.query("query", 1, "string").close + db.query("query", Bytes.new(4)).close + db.query("query", 1, "string", FooValue.new(5)).close + db.query("query", [1, "string", FooValue.new(5)]).close + end + + DB.open("bar://host") do |db| + BarDriver.fake_row = [1, "string"] of BarDriver::Any + db.query "query" { } + db.query "query", 1 { } + db.query "query", 1, "string" { } + db.query("query", Bytes.new(4)) { } + db.query("query", 1, "string", BarValue.new(5)) { } + db.query "query", [1, "string", BarValue.new(5)] { } + + db.query("query").close + db.query("query", 1).close + db.query("query", 1, "string").close + db.query("query", Bytes.new(4)).close + db.query("query", 1, "string", BarValue.new(5)).close + db.query("query", [1, "string", BarValue.new(5)]).close + end + end + + it "allow custom types to be used as arguments for exec" do + DB.open("foo://host") do |db| + FooDriver.fake_row = [1, "string"] of FooDriver::Any + db.exec("query") + db.exec("query", 1) + db.exec("query", 1, "string") + db.exec("query", Bytes.new(4)) + db.exec("query", 1, "string", FooValue.new(5)) + db.exec("query", [1, "string", FooValue.new(5)]) + end + + DB.open("bar://host") do |db| + BarDriver.fake_row = [1, "string"] of BarDriver::Any + db.exec("query") + db.exec("query", 1) + db.exec("query", 1, "string") + db.exec("query", Bytes.new(4)) + db.exec("query", 1, "string", BarValue.new(5)) + db.exec("query", [1, "string", BarValue.new(5)]) + end + end + + it "Foo and Bar drivers should not implement each other params" do + DB.open("foo://host") do |db| + expect_raises Exception, "FooDriver::FooStatement does not support BarValue params" do + db.exec("query", [BarValue.new(5)]) + end + end + + DB.open("bar://host") do |db| + expect_raises Exception, "BarDriver::BarStatement does not support FooValue params" do + db.exec("query", [FooValue.new(5)]) + end + end + end +end diff --git a/lib/db/spec/database_spec.cr b/lib/db/spec/database_spec.cr new file mode 100644 index 00000000..e3aeb79b --- /dev/null +++ b/lib/db/spec/database_spec.cr @@ -0,0 +1,255 @@ +require "./spec_helper" + +describe DB::Database do + it "allows connection initialization" do + cnn_setup = 0 + DB.open "dummy://localhost:1027?initial_pool_size=2&max_pool_size=4&max_idle_pool_size=1" do |db| + cnn_setup.should eq(0) + + db.setup_connection do |cnn| + cnn_setup += 1 + end + + cnn_setup.should eq(2) + + db.using_connection do + cnn_setup.should eq(2) + db.using_connection do + cnn_setup.should eq(2) + db.using_connection do + cnn_setup.should eq(3) + db.using_connection do + cnn_setup.should eq(4) + end + # the pool didn't shrink no new initialization should be done next + db.using_connection do + cnn_setup.should eq(4) + end + end + # the pool shrink 1. max_idle_pool_size=1 + # after the previous end there where 2. + db.using_connection do + cnn_setup.should eq(4) + # so now there will be a new connection created + db.using_connection do + cnn_setup.should eq(5) + end + end + end + end + end + end + + it "should allow creation of more statements than pool connections" do + DB.open "dummy://localhost:1027?initial_pool_size=1&max_pool_size=2" do |db| + db.build("query1").should be_a(DB::PoolPreparedStatement) + db.build("query2").should be_a(DB::PoolPreparedStatement) + db.build("query3").should be_a(DB::PoolPreparedStatement) + end + end + + it "should return same statement in pool per query" do + with_dummy do |db| + stmt = db.build("query1") + db.build("query2").should_not eq(stmt) + db.build("query1").should eq(stmt) + end + end + + it "should close pool statements when closing db" do + stmt = uninitialized DB::PoolStatement + with_dummy do |db| + stmt = db.build("query1") + end + stmt.closed?.should be_true + end + + it "should not reconnect if connection is lost and retry_attempts=0" do + DummyDriver::DummyConnection.clear_connections + DB.open "dummy://localhost:1027?initial_pool_size=1&max_pool_size=1&retry_attempts=0" do |db| + db.exec("stmt1") + DummyDriver::DummyConnection.connections.size.should eq(1) + DummyDriver::DummyConnection.connections.first.disconnect! + expect_raises DB::PoolRetryAttemptsExceeded do + db.exec("stmt1") + end + DummyDriver::DummyConnection.connections.size.should eq(1) + end + end + + it "should reconnect if connection is lost and executing same statement" do + DummyDriver::DummyConnection.clear_connections + DB.open "dummy://localhost:1027?initial_pool_size=1&max_pool_size=1&retry_attempts=1" do |db| + db.exec("stmt1") + DummyDriver::DummyConnection.connections.size.should eq(1) + DummyDriver::DummyConnection.connections.first.disconnect! + db.exec("stmt1") + DummyDriver::DummyConnection.connections.size.should eq(2) + end + end + + it "should allow new connections if pool can increased and retry is not allowed" do + DummyDriver::DummyConnection.clear_connections + DB.open "dummy://localhost:1027?initial_pool_size=1&max_pool_size=2&retry_attempts=0" do |db| + db.query("stmt1") + DummyDriver::DummyConnection.connections.size.should eq(1) + db.query("stmt1") + DummyDriver::DummyConnection.connections.size.should eq(2) + end + end + + it "should not return connection to pool if checkout explicitly" do + DummyDriver::DummyConnection.clear_connections + DB.open "dummy://localhost:1027?initial_pool_size=1&max_pool_size=1&retry_attempts=0" do |db| + the_cnn = uninitialized DB::Connection + db.using_connection do |cnn| + the_cnn = cnn + db.pool.is_available?(cnn).should be_false + 3.times do + cnn.exec("stmt1") + db.pool.is_available?(cnn).should be_false + end + end + db.pool.is_available?(the_cnn).should be_true + end + end + + it "should checkout different connections until they are released" do + DummyDriver::DummyConnection.clear_connections + DB.open "dummy://localhost:1027?initial_pool_size=1&max_pool_size=2&retry_attempts=0" do |db| + the_first_cnn = uninitialized DB::Connection + the_second_cnn = uninitialized DB::Connection + + the_first_cnn = db.checkout + the_second_cnn = db.checkout + the_second_cnn.should_not eq(the_first_cnn) + db.pool.is_available?(the_first_cnn).should be_false + db.pool.is_available?(the_second_cnn).should be_false + + the_first_cnn.release + db.pool.is_available?(the_first_cnn).should be_true + db.pool.is_available?(the_second_cnn).should be_false + + db.checkout.should eq(the_first_cnn) + the_first_cnn.release + the_second_cnn.release + end + end + + it "should not return explicit checked out connections to the pool after query" do + DummyDriver::DummyConnection.clear_connections + DB.open "dummy://localhost:1027?initial_pool_size=1&max_pool_size=2&retry_attempts=0" do |db| + cnn = db.checkout + + cnn.query_all("1", as: String) + + db.pool.is_available?(cnn).should be_false + cnn.release + db.pool.is_available?(cnn).should be_true + end + end + + it "should return connection to the pool if prepared statement is unable to be built" do + connection = uninitialized DB::Connection + with_dummy "dummy://localhost:1027?initial_pool_size=1" do |db| + connection = DummyDriver::DummyConnection.connections.first + expect_raises DB::Error do + db.prepared.exec("syntax error") + end + db.pool.is_available?(connection).should be_true + end + end + + it "should return connection to the pool if unprepared statement is unable to be built" do + connection = uninitialized DB::Connection + with_dummy "dummy://localhost:1027?initial_pool_size=1" do |db| + connection = DummyDriver::DummyConnection.connections.first + expect_raises DB::Error do + db.unprepared.exec("syntax error") + end + db.pool.is_available?(connection).should be_true + end + end + + describe "prepared_statements connection option" do + it "defaults to true" do + with_dummy "dummy://localhost:1027" do |db| + db.prepared_statements?.should be_true + end + end + + it "can be set to false" do + with_dummy "dummy://localhost:1027?prepared_statements=false" do |db| + db.prepared_statements?.should be_false + end + end + + it "is copied to connections (false)" do + with_dummy "dummy://localhost:1027?prepared_statements=false&initial_pool_size=1" do |db| + connection = DummyDriver::DummyConnection.connections.first + connection.prepared_statements?.should be_false + end + end + + it "is copied to connections (true)" do + with_dummy "dummy://localhost:1027?prepared_statements=true&initial_pool_size=1" do |db| + connection = DummyDriver::DummyConnection.connections.first + connection.prepared_statements?.should be_true + end + end + + it "should build prepared statements if true" do + with_dummy "dummy://localhost:1027?prepared_statements=true" do |db| + db.build("the query").should be_a(DB::PoolPreparedStatement) + end + end + + it "should build unprepared statements if false" do + with_dummy "dummy://localhost:1027?prepared_statements=false" do |db| + db.build("the query").should be_a(DB::PoolUnpreparedStatement) + end + end + + it "should be overrided by dsl" do + with_dummy "dummy://localhost:1027?prepared_statements=true" do |db| + stmt = db.unprepared.query("the query").statement.as(DummyDriver::DummyStatement) + stmt.prepared?.should be_false + end + + with_dummy "dummy://localhost:1027?prepared_statements=false" do |db| + stmt = db.prepared.query("the query").statement.as(DummyDriver::DummyStatement) + stmt.prepared?.should be_true + end + end + end + + describe "unprepared statements in pool" do + it "creating statements should not create new connections" do + with_dummy "dummy://localhost:1027?initial_pool_size=1" do |db| + stmt1 = db.unprepared.build("query1") + stmt2 = db.unprepared.build("query2") + DummyDriver::DummyConnection.connections.size.should eq(1) + end + end + + it "simultaneous statements should go to different connections" do + with_dummy "dummy://localhost:1027?initial_pool_size=1" do |db| + rs1 = db.unprepared.query("query1") + rs2 = db.unprepared.query("query2") + rs1.statement.connection.should_not eq(rs2.statement.connection) + DummyDriver::DummyConnection.connections.size.should eq(2) + end + end + + it "sequential statements should go to different connections" do + with_dummy "dummy://localhost:1027?initial_pool_size=1" do |db| + rs1 = db.unprepared.query("query1") + rs1.close + rs2 = db.unprepared.query("query2") + rs2.close + rs1.statement.connection.should eq(rs2.statement.connection) + DummyDriver::DummyConnection.connections.size.should eq(1) + end + end + end +end diff --git a/lib/db/spec/db_spec.cr b/lib/db/spec/db_spec.cr new file mode 100644 index 00000000..d5836c3a --- /dev/null +++ b/lib/db/spec/db_spec.cr @@ -0,0 +1,134 @@ +require "./spec_helper" + +private def connections + DummyDriver::DummyConnection.connections +end + +describe DB do + it "should get driver class by name" do + DB.driver_class("dummy").should eq(DummyDriver) + end + + it "should instantiate driver with connection uri" do + db = DB.open "dummy://localhost:1027" + db.driver.should be_a(DummyDriver) + db.uri.scheme.should eq("dummy") + db.uri.host.should eq("localhost") + db.uri.port.should eq(1027) + end + + it "should create a connection and close it" do + DummyDriver::DummyConnection.clear_connections + DB.open "dummy://localhost" do |db| + end + connections.size.should eq(1) + connections.first.closed?.should be_true + end + + it "should create a connection and close it" do + DummyDriver::DummyConnection.clear_connections + DB.connect "dummy://localhost" do |cnn| + cnn.should be_a(DummyDriver::DummyConnection) + end + connections.size.should eq(1) + connections.first.closed?.should be_true + end + + it "should create a connection and wait for explicit closing" do + DummyDriver::DummyConnection.clear_connections + cnn = DB.connect "dummy://localhost" + cnn.should be_a(DummyDriver::DummyConnection) + connections.size.should eq(1) + connections.first.closed?.should be_false + cnn.close + connections.first.closed?.should be_true + end + + it "query should close result_set" do + with_witness do |w| + with_dummy do |db| + db.query "1,2" do + break + end + + w.check + DummyDriver::DummyResultSet.last_result_set.closed?.should be_true + end + end + end + + it "scalar should close statement" do + with_dummy do |db| + db.scalar "1" + DummyDriver::DummyResultSet.last_result_set.closed?.should be_true + end + end + + it "initially a single connection should be created" do + with_dummy do |db| + connections.size.should eq(1) + end + end + + it "the connection should be closed after db usage" do + with_dummy do |db| + connections.first.closed?.should be_false + end + connections.first.closed?.should be_true + end + + it "should raise if the sole connection is been used" do + with_dummy "dummy://host?max_pool_size=1&checkout_timeout=0.5" do |db| + db.query "1" do |rs| + expect_raises DB::PoolTimeout do + db.scalar "2" + end + end + end + end + + it "should use 'unlimited' connections by default" do + with_dummy "dummy://host?checkout_timeout=0.5" do |db| + rs = [] of DB::ResultSet + 500.times do + rs << db.query "1" + end + DummyDriver::DummyConnection.connections.size.should eq(500) + end + end + + it "exec should return to pool" do + with_dummy do |db| + db.exec "foo" + db.exec "bar" + end + end + + it "scalar should return to pool" do + with_dummy do |db| + db.scalar "foo" + db.scalar "bar" + end + end + + it "gives nice error message when no driver is registered for schema (#21)" do + expect_raises(ArgumentError, %(no driver was registered for the schema "foobar", did you maybe forget to require the database driver?)) do + DB.open "foobar://baz" + end + end + + it "should parse boolean query string params" do + DB.fetch_bool(HTTP::Params.parse("foo=true"), "foo", false).should be_true + DB.fetch_bool(HTTP::Params.parse("foo=True"), "foo", false).should be_true + + DB.fetch_bool(HTTP::Params.parse("foo=false"), "foo", true).should be_false + DB.fetch_bool(HTTP::Params.parse("foo=False"), "foo", true).should be_false + + DB.fetch_bool(HTTP::Params.parse("bar=true"), "foo", false).should be_false + DB.fetch_bool(HTTP::Params.parse("bar=true"), "foo", true).should be_true + + expect_raises(ArgumentError, %(invalid "other" value for option "foo")) do + DB.fetch_bool(HTTP::Params.parse("foo=other"), "foo", true) + end + end +end diff --git a/lib/db/spec/disposable_spec.cr b/lib/db/spec/disposable_spec.cr new file mode 100644 index 00000000..3b749d89 --- /dev/null +++ b/lib/db/spec/disposable_spec.cr @@ -0,0 +1,31 @@ +require "./spec_helper" + +class ADisposable + include DB::Disposable + @raise = false + + property raise + + protected def do_close + raise "Unable to close" if @raise + end +end + +describe DB::Disposable do + it "should mark as closed if able to close" do + obj = ADisposable.new + obj.closed?.should be_false + obj.close + obj.closed?.should be_true + end + + it "should not mark as closed if unable to close" do + obj = ADisposable.new + obj.raise = true + obj.closed?.should be_false + expect_raises Exception do + obj.close + end + obj.closed?.should be_false + end +end diff --git a/lib/db/spec/dummy_driver.cr b/lib/db/spec/dummy_driver.cr new file mode 100644 index 00000000..830d3719 --- /dev/null +++ b/lib/db/spec/dummy_driver.cr @@ -0,0 +1,268 @@ +require "spec" +require "../src/db" + +class DummyDriver < DB::Driver + def build_connection(context : DB::ConnectionContext) : DB::Connection + DummyConnection.new(context) + end + + class DummyConnection < DB::Connection + def initialize(context) + super(context) + @connected = true + @@connections ||= [] of DummyConnection + @@connections.not_nil! << self + end + + def self.connections + @@connections.not_nil! + end + + def self.clear_connections + @@connections.try &.clear + end + + def build_prepared_statement(query) : DB::Statement + DummyStatement.new(self, query, true) + end + + def build_unprepared_statement(query) : DB::Statement + DummyStatement.new(self, query, false) + end + + def last_insert_id : Int64 + 0 + end + + def check + raise DB::ConnectionLost.new(self) unless @connected + end + + def disconnect! + @connected = false + end + + def create_transaction + DummyTransaction.new(self) + end + + protected def do_close + super + end + end + + class DummyTransaction < DB::TopLevelTransaction + getter committed = false + getter rolledback = false + + def initialize(connection) + super(connection) + end + + def commit + super + @committed = true + end + + def rollback + super + @rolledback = true + end + + protected def create_save_point_transaction(parent, savepoint_name : String) + DummySavePointTransaction.new(parent, savepoint_name) + end + end + + class DummySavePointTransaction < DB::SavePointTransaction + getter committed = false + getter rolledback = false + + def initialize(parent, savepoint_name) + super(parent, savepoint_name) + end + + def commit + super + @committed = true + end + + def rollback + super + @rolledback = true + end + end + + class DummyStatement < DB::Statement + property params + + def initialize(connection, @query : String, @prepared : Bool) + @params = Hash(Int32 | String, DB::Any).new + super(connection) + raise DB::Error.new(query) if query == "syntax error" + end + + protected def perform_query(args : Enumerable) : DB::ResultSet + @connection.as(DummyConnection).check + set_params args + DummyResultSet.new self, @query + end + + protected def perform_exec(args : Enumerable) : DB::ExecResult + @connection.as(DummyConnection).check + set_params args + raise DB::Error.new("forced exception due to query") if @query == "raise" + DB::ExecResult.new 0i64, 0_i64 + end + + private def set_params(args) + @params.clear + args.each_with_index do |arg, index| + set_param(index, arg) + end + end + + private def set_param(index, value : DB::Any) + @params[index] = value + end + + private def set_param(index, value) + raise "not implemented for #{value.class}" + end + + def prepared? + @prepared + end + + protected def do_close + super + end + end + + class DummyResultSet < DB::ResultSet + @top_values : Array(Array(String)) + @values : Array(String)? + + @@last_result_set : self? + + def initialize(statement, query) + super(statement) + @top_values = query.split.map { |r| r.split(',') }.to_a + @column_count = @top_values.size > 0 ? @top_values[0].size : 2 + + @@last_result_set = self + end + + protected def do_close + super + end + + def self.last_result_set + @@last_result_set.not_nil! + end + + def move_next : Bool + @values = @top_values.shift? + !!@values + end + + def column_count : Int32 + @column_count + end + + def column_name(index) : String + "c#{index}" + end + + def read + n = @values.not_nil!.shift? + raise "end of row" if n.is_a?(Nil) + return nil if n == "NULL" + + if n == "?" + return (@statement.as(DummyStatement)).params[0] + end + + return n + end + + def read(t : String.class) + read.to_s + end + + def read(t : String?.class) + read.try &.to_s + end + + def read(t : Int32.class) + read(String).to_i32 + end + + def read(t : Int32?.class) + read(String?).try &.to_i32 + end + + def read(t : Int64.class) + read(String).to_i64 + end + + def read(t : Int64?.class) + read(String?).try &.to_i64 + end + + def read(t : Float32.class) + read(String).to_f32 + end + + def read(t : Float64.class) + read(String).to_f64 + end + + def read(t : Bytes.class) + case value = read + when String + ary = value.bytes + Slice.new(ary.to_unsafe, ary.size) + when Bytes + value + else + raise "#{value} is not convertible to Bytes" + end + end + end +end + +DB.register_driver "dummy", DummyDriver + +class Witness + getter count + + def initialize(@count = 1) + end + + def check + @count -= 1 + end +end + +def with_witness(count = 1) + w = Witness.new(count) + yield w + w.count.should eq(0), "The expected coverage was unmet" +end + +def with_dummy(uri : String = "dummy://host?checkout_timeout=0.5") + DummyDriver::DummyConnection.clear_connections + + DB.open uri do |db| + yield db + end +end + +def with_dummy_connection(options = "") + with_dummy("dummy://host?checkout_timeout=0.5&#{options}") do |db| + db.using_connection do |cnn| + yield cnn.as(DummyDriver::DummyConnection) + end + end +end diff --git a/lib/db/spec/dummy_driver_spec.cr b/lib/db/spec/dummy_driver_spec.cr new file mode 100644 index 00000000..688a71d4 --- /dev/null +++ b/lib/db/spec/dummy_driver_spec.cr @@ -0,0 +1,310 @@ +require "./spec_helper" + +describe DummyDriver do + it "with_dummy executes the block with a database" do + with_witness do |w| + with_dummy do |db| + w.check + db.should be_a(DB::Database) + end + end + end + + describe DummyDriver::DummyStatement do + it "should enumerate split rows by spaces" do + with_dummy do |db| + rs = db.query("") + rs.move_next.should be_false + rs.close + + rs = db.query("a,b") + rs.move_next.should be_true + rs.move_next.should be_false + rs.close + + rs = db.query("a,b 1,2") + rs.move_next.should be_true + rs.move_next.should be_true + rs.move_next.should be_false + rs.close + + rs = db.query("a,b 1,2 c,d") + rs.move_next.should be_true + rs.move_next.should be_true + rs.move_next.should be_true + rs.move_next.should be_false + rs.close + end + end + + # it "should query with block should executes always" do + # with_witness do |w| + # with_dummy do |db| + # db.query "a" do |rs| + # w.check + # end + # end + # end + # end + # + # it "should query with block should executes always" do + # with_witness do |w| + # with_dummy do |db| + # db.query "lorem ipsum" do |rs| + # w.check + # end + # end + # end + # end + + it "should enumerate string fields" do + with_dummy do |db| + db.query "a,b 1,2" do |rs| + rs.move_next + rs.read(String).should eq("a") + rs.read(String).should eq("b") + rs.move_next + rs.read(String).should eq("1") + rs.read(String).should eq("2") + end + end + end + + it "should enumerate nil fields" do + with_dummy do |db| + db.query "a,NULL 1,NULL" do |rs| + rs.move_next + rs.read(String).should eq("a") + rs.read(String | Nil).should be_nil + rs.move_next + rs.read(Int64).should eq(1) + rs.read(Int64 | Nil).should be_nil + end + end + end + + it "should enumerate int64 fields" do + with_dummy do |db| + db.query "3,4 1,2" do |rs| + rs.move_next + rs.read(Int64).should eq(3i64) + rs.read(Int64).should eq(4i64) + rs.move_next + rs.read(Int64).should eq(1i64) + rs.read(Int64).should eq(2i64) + end + end + end + + it "should enumerate nillable int64 fields" do + with_dummy do |db| + db.query "3,4 1,NULL" do |rs| + rs.move_next + rs.read(Int64 | Nil).should eq(3i64) + rs.read(Int64 | Nil).should eq(4i64) + rs.move_next + rs.read(Int64 | Nil).should eq(1i64) + rs.read(Int64 | Nil).should be_nil + end + end + end + + describe "query one" do + it "queries" do + with_dummy do |db| + db.query_one("3,4", &.read(Int64, Int64)).should eq({3i64, 4i64}) + end + end + + it "raises if more than one row" do + with_dummy do |db| + expect_raises(DB::Error, "more than one row") do + db.query_one("3,4 5,6") { } + end + end + end + + it "raises if no rows" do + with_dummy do |db| + expect_raises(DB::Error, "no rows") do + db.query_one("") { } + end + end + end + + it "with as" do + with_dummy do |db| + db.query_one("3,4", as: {Int64, Int64}).should eq({3i64, 4i64}) + end + end + + it "with a named tuple" do + with_dummy do |db| + db.query_one("3,4", as: {a: Int64, b: Int64}).should eq({a: 3i64, b: 4i64}) + end + end + + it "with as, just one" do + with_dummy do |db| + db.query_one("3", as: Int64).should eq(3i64) + end + end + end + + describe "query one?" do + it "queries" do + with_dummy do |db| + value = db.query_one?("3,4", &.read(Int64, Int64)) + value.should eq({3i64, 4i64}) + value.should be_a(Tuple(Int64, Int64)?) + end + end + + it "raises if more than one row" do + with_dummy do |db| + expect_raises(DB::Error, "more than one row") do + db.query_one?("3,4 5,6") { } + end + end + end + + it "returns nil if no rows" do + with_dummy do |db| + db.query_one?("") { fail("block shouldn't be invoked") }.should be_nil + end + end + + it "with as" do + with_dummy do |db| + value = db.query_one?("3,4", as: {Int64, Int64}) + value.should be_a(Tuple(Int64, Int64)?) + value.should eq({3i64, 4i64}) + end + end + + it "with as" do + with_dummy do |db| + value = db.query_one?("3,4", as: {a: Int64, b: Int64}) + value.should be_a(NamedTuple(a: Int64, b: Int64)?) + value.should eq({a: 3i64, b: 4i64}) + end + end + + it "with as, no rows" do + with_dummy do |db| + value = db.query_one?("", as: {a: Int64, b: Int64}) + value.should be_nil + end + end + + it "with as, just one" do + with_dummy do |db| + value = db.query_one?("3", as: Int64) + value.should be_a(Int64?) + value.should eq(3i64) + end + end + end + + describe "query all" do + it "queries" do + with_dummy do |db| + ary = db.query_all "3,4 1,2", &.read(Int64, Int64) + ary.should eq([{3, 4}, {1, 2}]) + end + end + + it "queries with as" do + with_dummy do |db| + ary = db.query_all "3,4 1,2", as: {Int64, Int64} + ary.should eq([{3, 4}, {1, 2}]) + end + end + + it "queries with a named tuple" do + with_dummy do |db| + ary = db.query_all "3,4 1,2", as: {a: Int64, b: Int64} + ary.should eq([{a: 3, b: 4}, {a: 1, b: 2}]) + end + end + + it "queries with as, just one" do + with_dummy do |db| + ary = db.query_all "3 1", as: Int64 + ary.should eq([3, 1]) + end + end + end + + describe "query each" do + it "queries" do + with_dummy do |db| + i = 0 + db.query_each "3,4 1,2" do |rs| + case i + when 0 + rs.read(Int64, Int64).should eq({3i64, 4i64}) + when 1 + rs.read(Int64, Int64).should eq({1i64, 2i64}) + end + i += 1 + end + i.should eq(2) + end + end + end + + it "reads multiple values" do + with_dummy do |db| + db.query "3,4 1,2" do |rs| + rs.move_next + rs.read(Int64, Int64).should eq({3i64, 4i64}) + rs.move_next + rs.read(Int64, Int64).should eq({1i64, 2i64}) + end + end + end + + it "should enumerate blob fields" do + with_dummy do |db| + db.query("az,AZ") do |rs| + rs.move_next + ary = [97u8, 122u8] + rs.read(Bytes).should eq(Bytes.new(ary.to_unsafe, ary.size)) + ary = [65u8, 90u8] + rs.read(Bytes).should eq(Bytes.new(ary.to_unsafe, ary.size)) + end + end + end + + it "should get Nil scalars" do + with_dummy do |db| + db.scalar("NULL").should be_nil + end + end + + it "should raise executing raise query" do + with_dummy do |db| + expect_raises DB::Error do + db.exec "raise" + end + end + end + + {% for value in [1, 1_i64, "hello", 1.5, 1.5_f32] %} + it "should set positional arguments for {{value.id}}" do + with_dummy do |db| + db.scalar("?", {{value}}).should eq({{value}}) + end + end + {% end %} + + it "executes and selects blob" do + with_dummy do |db| + ary = UInt8[0x53, 0x51, 0x4C] + slice = Bytes.new(ary.to_unsafe, ary.size) + (db.scalar("?", slice).as(Bytes)).to_a.should eq(ary) + end + end + end +end diff --git a/lib/db/spec/mapping_spec.cr b/lib/db/spec/mapping_spec.cr new file mode 100644 index 00000000..dc8f7e60 --- /dev/null +++ b/lib/db/spec/mapping_spec.cr @@ -0,0 +1,194 @@ +require "./spec_helper" +require "base64" + +class SimpleMapping + DB.mapping({ + c0: Int32, + c1: String, + }) +end + +class NonStrictMapping + DB.mapping({ + c1: Int32, + c2: String, + }, strict: false) +end + +class MappingWithDefaults + DB.mapping({ + c0: {type: Int32, default: 10}, + c1: {type: String, default: "c"}, + }) +end + +class MappingWithNilables + DB.mapping({ + c0: {type: Int32, nilable: true, default: 10}, + c1: {type: String, nilable: true}, + }) +end + +class MappingWithNilTypes + DB.mapping({ + c0: {type: Int32?, default: 10}, + c1: String?, + }) +end + +class MappingWithNilUnionTypes + DB.mapping({ + c0: {type: Int32 | Nil, default: 10}, + c1: Nil | String, + }) +end + +class MappingWithKeys + DB.mapping({ + foo: {type: Int32, key: "c0"}, + bar: {type: String, key: "c1"}, + }) +end + +class MappingWithConverter + module Base64Converter + def self.from_rs(rs) + Base64.decode(rs.read(String)) + end + end + + DB.mapping({ + c0: {type: Slice(UInt8), converter: MappingWithConverter::Base64Converter}, + c1: {type: String}, + }) +end + +macro from_dummy(query, type) + with_dummy do |db| + rs = db.query({{ query }}) + rs.move_next + %obj = {{ type }}.new(rs) + rs.close + %obj + end +end + +macro expect_mapping(query, t, values) + %obj = from_dummy({{ query }}, {{ t }}) + %obj.should be_a({{ t }}) + {% for key, value in values %} + %obj.{{key.id}}.should eq({{value}}) + {% end %} +end + +describe "DB.mapping" do + it "should initialize a simple mapping" do + expect_mapping("1,a", SimpleMapping, {c0: 1, c1: "a"}) + end + + it "should fail to initialize a simple mapping if types do not match" do + expect_raises ArgumentError do + from_dummy("b,a", SimpleMapping) + end + end + + it "should fail to initialize a simple mapping if there is a missing column" do + expect_raises DB::MappingException do + from_dummy("1", SimpleMapping) + end + end + + it "should fail to initialize a simple mapping if there is an unexpected column" do + expect_raises DB::MappingException do + from_dummy("1,a,b", SimpleMapping) + end + end + + it "should initialize a non-strict mapping if there is an unexpected column" do + expect_mapping("1,2,a,b", NonStrictMapping, {c1: 2, c2: "a"}) + end + + it "should initialize a mapping with default values" do + expect_mapping("1,a", MappingWithDefaults, {c0: 1, c1: "a"}) + end + + it "should initialize a mapping using default values if columns are missing" do + expect_mapping("1", MappingWithDefaults, {c0: 1, c1: "c"}) + end + + it "should initialize a mapping using default values if values are nil and types are non nilable" do + expect_mapping("1,NULL", MappingWithDefaults, {c0: 1, c1: "c"}) + end + + it "should initialize a mapping with nilable set if columns are missing" do + expect_mapping("1", MappingWithNilables, {c0: 1, c1: nil}) + end + + it "should initialize a mapping with nilable set ignoring default value if NULL" do + expect_mapping("NULL,a", MappingWithNilables, {c0: nil, c1: "a"}) + end + + it "should initialize a mapping with nilable types if columns are missing" do + expect_mapping("1", MappingWithNilTypes, {c0: 1, c1: nil}) + expect_mapping("1", MappingWithNilUnionTypes, {c0: 1, c1: nil}) + end + + it "should initialize a mapping with nilable types ignoring default value if NULL" do + expect_mapping("NULL,a", MappingWithNilTypes, {c0: nil, c1: "a"}) + expect_mapping("NULL,a", MappingWithNilUnionTypes, {c0: nil, c1: "a"}) + end + + it "should initialize a mapping with different keys" do + expect_mapping("1,a", MappingWithKeys, {foo: 1, bar: "a"}) + end + + it "should initialize a mapping with a value converter" do + expect_mapping("Zm9v,a", MappingWithConverter, {c0: "foo".to_slice, c1: "a"}) + end + + it "should initialize multiple instances from a single resultset" do + with_dummy do |db| + db.query("1,a 2,b") do |rs| + objs = SimpleMapping.from_rs(rs) + objs.size.should eq(2) + objs[0].c0.should eq(1) + objs[0].c1.should eq("a") + objs[1].c0.should eq(2) + objs[1].c1.should eq("b") + end + end + end + + it "Class.from_rs should close resultset" do + with_dummy do |db| + rs = db.query("1,a 2,b") + objs = SimpleMapping.from_rs(rs) + rs.closed?.should be_true + + objs.size.should eq(2) + objs[0].c0.should eq(1) + objs[0].c1.should eq("a") + objs[1].c0.should eq(2) + objs[1].c1.should eq("b") + end + end + + it "should initialize from a query_one" do + with_dummy do |db| + obj = db.query_one "1,a", as: SimpleMapping + obj.c0.should eq(1) + obj.c1.should eq("a") + end + end + + it "should initialize from a query_all" do + with_dummy do |db| + objs = db.query_all "1,a 2,b", as: SimpleMapping + objs.size.should eq(2) + objs[0].c0.should eq(1) + objs[0].c1.should eq("a") + objs[1].c0.should eq(2) + objs[1].c1.should eq("b") + end + end +end diff --git a/lib/db/spec/pool_spec.cr b/lib/db/spec/pool_spec.cr new file mode 100644 index 00000000..6a1612cc --- /dev/null +++ b/lib/db/spec/pool_spec.cr @@ -0,0 +1,206 @@ +require "./spec_helper" + +class ShouldSleepingOp + @is_sleeping = false + getter is_sleeping + getter sleep_happened + + def initialize + @sleep_happened = Channel(Nil).new + end + + def should_sleep + s = self + @is_sleeping = true + spawn do + sleep 0.1 + s.is_sleeping.should be_true + s.sleep_happened.send(nil) + end + yield + @is_sleeping = false + end + + def wait_for_sleep + @sleep_happened.receive + end +end + +class WaitFor + def initialize + @channel = Channel(Nil).new + end + + def wait + @channel.receive + end + + def check + @channel.send(nil) + end +end + +class Closable + include DB::Disposable + property before_checkout_called : Bool = false + property after_release_called : Bool = false + + protected def do_close + end + + def before_checkout + @before_checkout_called = true + end + + def after_release + @after_release_called = true + end +end + +describe DB::Pool do + it "should use proc to create objects" do + block_called = 0 + pool = DB::Pool.new(initial_pool_size: 3) { block_called += 1; Closable.new } + block_called.should eq(3) + end + + it "should get resource" do + pool = DB::Pool.new { Closable.new } + resource = pool.checkout + resource.should be_a Closable + resource.before_checkout_called.should be_true + end + + it "should be available if not checkedout" do + resource = uninitialized Closable + pool = DB::Pool.new(initial_pool_size: 1) { resource = Closable.new } + pool.is_available?(resource).should be_true + end + + it "should not be available if checkedout" do + pool = DB::Pool.new { Closable.new } + resource = pool.checkout + pool.is_available?(resource).should be_false + end + + it "should be available if returned" do + pool = DB::Pool.new { Closable.new } + resource = pool.checkout + resource.after_release_called.should be_false + pool.release resource + pool.is_available?(resource).should be_true + resource.after_release_called.should be_true + end + + it "should wait for available resource" do + pool = DB::Pool.new(max_pool_size: 1, initial_pool_size: 1) { Closable.new } + + b_cnn_request = ShouldSleepingOp.new + wait_a = WaitFor.new + wait_b = WaitFor.new + + spawn do + a_cnn = pool.checkout + b_cnn_request.wait_for_sleep + pool.release a_cnn + + wait_a.check + end + + spawn do + b_cnn_request.should_sleep do + pool.checkout + end + + wait_b.check + end + + wait_a.wait + wait_b.wait + end + + it "should create new if max was not reached" do + block_called = 0 + pool = DB::Pool.new(max_pool_size: 2, initial_pool_size: 1) { block_called += 1; Closable.new } + block_called.should eq 1 + pool.checkout + block_called.should eq 1 + pool.checkout + block_called.should eq 2 + end + + it "should reuse returned resources" do + all = [] of Closable + pool = DB::Pool.new(max_pool_size: 2, initial_pool_size: 1) { Closable.new.tap { |c| all << c } } + pool.checkout + b1 = pool.checkout + pool.release b1 + b2 = pool.checkout + + b1.should eq b2 + all.size.should eq 2 + end + + it "should close available and total" do + all = [] of Closable + pool = DB::Pool.new(max_pool_size: 2, initial_pool_size: 1) { Closable.new.tap { |c| all << c } } + a = pool.checkout + b = pool.checkout + pool.release b + all.size.should eq 2 + + all[0].closed?.should be_false + all[1].closed?.should be_false + pool.close + all[0].closed?.should be_true + all[1].closed?.should be_true + end + + it "should timeout" do + pool = DB::Pool.new(max_pool_size: 1, checkout_timeout: 0.1) { Closable.new } + pool.checkout + expect_raises DB::PoolTimeout do + pool.checkout + end + end + + it "should be able to release after a timeout" do + pool = DB::Pool.new(max_pool_size: 1, checkout_timeout: 0.1) { Closable.new } + a = pool.checkout + pool.checkout rescue nil + pool.release a + end + + it "should close if max idle amount is reached" do + all = [] of Closable + pool = DB::Pool.new(max_pool_size: 3, max_idle_pool_size: 1) { Closable.new.tap { |c| all << c } } + pool.checkout + pool.checkout + pool.checkout + + all.size.should eq 3 + all.any?(&.closed?).should be_false + pool.release all[0] + + all.any?(&.closed?).should be_false + pool.release all[1] + + all[0].closed?.should be_false + all[1].closed?.should be_true + all[2].closed?.should be_false + end + + it "should create resource after max_pool was reached if idle forced some close up" do + all = [] of Closable + pool = DB::Pool.new(max_pool_size: 3, max_idle_pool_size: 1) { Closable.new.tap { |c| all << c } } + pool.checkout + pool.checkout + pool.checkout + pool.release all[0] + pool.release all[1] + pool.checkout + pool.checkout + + all.size.should eq 4 + end +end diff --git a/lib/db/spec/result_set_spec.cr b/lib/db/spec/result_set_spec.cr new file mode 100644 index 00000000..bb633fdd --- /dev/null +++ b/lib/db/spec/result_set_spec.cr @@ -0,0 +1,67 @@ +require "./spec_helper" + +class DummyException < Exception +end + +describe DB::ResultSet do + it "should enumerate records using each" do + nums = [] of Int32 + + with_dummy do |db| + db.query "3,4 1,2" do |rs| + rs.each do + nums << rs.read(Int32) + nums << rs.read(Int32) + end + end + end + + nums.should eq([3, 4, 1, 2]) + end + + it "should close ResultSet after query" do + with_dummy do |db| + the_rs = uninitialized DB::ResultSet + db.query "3,4 1,2" do |rs| + the_rs = rs + end + the_rs.closed?.should be_true + end + end + + it "should close ResultSet after query even with exception" do + with_dummy do |db| + the_rs = uninitialized DB::ResultSet + begin + db.query "3,4 1,2" do |rs| + the_rs = rs + raise DummyException.new + end + rescue DummyException + end + the_rs.closed?.should be_true + end + end + + it "should enumerate columns" do + cols = [] of String + + with_dummy do |db| + db.query "3,4 1,2" do |rs| + rs.each_column do |col| + cols << col + end + end + end + + cols.should eq(["c0", "c1"]) + end + + it "gets all column names" do + with_dummy do |db| + db.query "1,2" do |rs| + rs.column_names.should eq(%w(c0 c1)) + end + end + end +end diff --git a/lib/db/spec/save_point_transaction_spec.cr b/lib/db/spec/save_point_transaction_spec.cr new file mode 100644 index 00000000..97d5ab45 --- /dev/null +++ b/lib/db/spec/save_point_transaction_spec.cr @@ -0,0 +1,160 @@ +require "./spec_helper" + +private class FooException < Exception +end + +private def with_dummy_top_transaction + with_dummy_connection do |cnn| + cnn.transaction do |tx| + yield tx.as(DummyDriver::DummyTransaction), cnn + end + end +end + +private def with_dummy_nested_transaction + with_dummy_connection do |cnn| + cnn.transaction do |tx| + tx.transaction do |nested| + yield nested.as(DummyDriver::DummySavePointTransaction), cnn + end + end + end +end + +describe DB::SavePointTransaction do + {% for context in [:with_dummy_top_transaction, :with_dummy_nested_transaction] %} + describe "{{context.id}}" do + it "begin/commit transaction from parent transaction" do + {{context.id}} do |parent_tx| + tx = parent_tx.begin_transaction + tx.commit + end + end + + it "begin/rollback transaction from parent transaction" do + {{context.id}} do |parent_tx| + tx = parent_tx.begin_transaction + tx.rollback + end + end + + it "raise if begin over existing transaction" do + {{context.id}} do |parent_tx| + parent_tx.begin_transaction + expect_raises(DB::Error, "There is an existing nested transaction in this transaction") do + parent_tx.begin_transaction + end + end + end + + it "allow sequential transactions" do + {{context.id}} do |parent_tx| + tx = parent_tx.begin_transaction + tx.rollback + + tx = parent_tx.begin_transaction + tx.commit + end + end + + it "transaction with block from parent transaction should be committed" do + t = uninitialized DummyDriver::DummySavePointTransaction + + with_witness do |w| + {{context.id}} do |parent_tx| + parent_tx.transaction do |tx| + if tx.is_a?(DummyDriver::DummySavePointTransaction) + t = tx + w.check + end + end + end + end + + t.committed.should be_true + t.rolledback.should be_false + end + end + {% end %} + + it "only nested transaction with block from parent transaction should be rolledback if raise DB::Rollback" do + top = uninitialized DummyDriver::DummyTransaction + t = uninitialized DummyDriver::DummySavePointTransaction + + with_witness do |w| + with_dummy_top_transaction do |parent_tx| + top = parent_tx + parent_tx.transaction do |tx| + if tx.is_a?(DummyDriver::DummySavePointTransaction) + t = tx + w.check + end + raise DB::Rollback.new + end + end + end + + t.rolledback.should be_true + t.committed.should be_false + + top.rolledback.should be_false + top.committed.should be_true + end + + it "only nested transaction with block from parent nested transaction should be rolledback if raise DB::Rollback" do + top = uninitialized DummyDriver::DummySavePointTransaction + t = uninitialized DummyDriver::DummySavePointTransaction + + with_witness do |w| + with_dummy_nested_transaction do |parent_tx| + top = parent_tx + parent_tx.transaction do |tx| + if tx.is_a?(DummyDriver::DummySavePointTransaction) + t = tx + w.check + end + raise DB::Rollback.new + end + end + end + + t.rolledback.should be_true + t.committed.should be_false + + top.rolledback.should be_false + top.committed.should be_true + end + + it "releasing result_set from within inner transaction should not return connection to pool" do + cnn = uninitialized DB::Connection + with_dummy do |db| + db.transaction do |tx| + tx.transaction do |inner| + cnn = inner.connection + cnn.scalar "1" + db.pool.is_available?(cnn).should be_false + end + db.pool.is_available?(cnn).should be_false + end + db.pool.is_available?(cnn).should be_true + end + end + + it "releasing result_set from within inner inner transaction should not return connection to pool" do + cnn = uninitialized DB::Connection + with_dummy do |db| + db.transaction do |tx| + tx.transaction do |inner| + inner.transaction do |inner_inner| + cnn = inner_inner.connection + cnn.scalar "1" + db.pool.is_available?(cnn).should be_false + end + db.pool.is_available?(cnn).should be_false + end + db.pool.is_available?(cnn).should be_false + end + db.pool.is_available?(cnn).should be_true + end + end +end diff --git a/lib/db/spec/spec_helper.cr b/lib/db/spec/spec_helper.cr new file mode 100644 index 00000000..a3a537d2 --- /dev/null +++ b/lib/db/spec/spec_helper.cr @@ -0,0 +1,3 @@ +require "spec" +require "./dummy_driver" +require "../src/db" diff --git a/lib/db/spec/statement_spec.cr b/lib/db/spec/statement_spec.cr new file mode 100644 index 00000000..fcdd05ad --- /dev/null +++ b/lib/db/spec/statement_spec.cr @@ -0,0 +1,157 @@ +require "./spec_helper" + +describe DB::Statement do + it "should build prepared statements" do + with_dummy_connection do |cnn| + prepared = cnn.prepared("the query") + prepared.should be_a(DB::Statement) + prepared.as(DummyDriver::DummyStatement).prepared?.should be_true + end + end + + it "should build unprepared statements" do + with_dummy_connection("prepared_statements=false") do |cnn| + prepared = cnn.unprepared("the query") + prepared.should be_a(DB::Statement) + prepared.as(DummyDriver::DummyStatement).prepared?.should be_false + end + end + + describe "prepared_statements flag" do + it "should build prepared statements if true" do + with_dummy_connection("prepared_statements=true") do |cnn| + stmt = cnn.query("the query").statement + stmt.as(DummyDriver::DummyStatement).prepared?.should be_true + end + end + + it "should build unprepared statements if false" do + with_dummy_connection("prepared_statements=false") do |cnn| + stmt = cnn.query("the query").statement + stmt.as(DummyDriver::DummyStatement).prepared?.should be_false + end + end + end + + it "should initialize positional params in query" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.query "a", 1, nil + stmt.params[0].should eq("a") + stmt.params[1].should eq(1) + stmt.params[2].should eq(nil) + end + end + + it "should initialize positional params in query with array" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.query ["a", 1, nil] + stmt.params[0].should eq("a") + stmt.params[1].should eq(1) + stmt.params[2].should eq(nil) + end + end + + it "should initialize positional params in exec" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.exec "a", 1, nil + stmt.params[0].should eq("a") + stmt.params[1].should eq(1) + stmt.params[2].should eq(nil) + end + end + + it "should initialize positional params in exec with array" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.exec ["a", 1, nil] + stmt.params[0].should eq("a") + stmt.params[1].should eq(1) + stmt.params[2].should eq(nil) + end + end + + it "should initialize positional params in scalar" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.scalar "a", 1, nil + stmt.params[0].should eq("a") + stmt.params[1].should eq(1) + stmt.params[2].should eq(nil) + end + end + + it "query with block should not close statement" do + with_dummy_connection do |cnn| + stmt = cnn.prepared "3,4 1,2" + stmt.query + stmt.closed?.should be_false + end + end + + it "closing connection should close statement" do + stmt = uninitialized DB::Statement + with_dummy_connection do |cnn| + stmt = cnn.prepared "3,4 1,2" + stmt.query + end + stmt.closed?.should be_true + end + + it "query with block should not close statement" do + with_dummy_connection do |cnn| + stmt = cnn.prepared "3,4 1,2" + stmt.query do |rs| + end + stmt.closed?.should be_false + end + end + + it "query should not close statement" do + with_dummy_connection do |cnn| + stmt = cnn.prepared "3,4 1,2" + stmt.query do |rs| + end + stmt.closed?.should be_false + end + end + + it "scalar should not close statement" do + with_dummy_connection do |cnn| + stmt = cnn.prepared "3,4 1,2" + stmt.scalar + stmt.closed?.should be_false + end + end + + it "exec should not close statement" do + with_dummy_connection do |cnn| + stmt = cnn.prepared "3,4 1,2" + stmt.exec + stmt.closed?.should be_false + end + end + + it "connection should cache statements by query" do + with_dummy_connection do |cnn| + rs = cnn.prepared.query "1, ?", 2 + stmt = rs.statement + rs.close + + rs = cnn.prepared.query "1, ?", 4 + rs.statement.should be(stmt) + end + end + + it "connection should be released if error occurs during exec" do + with_dummy do |db| + expect_raises DB::Error do + db.exec "raise" + end + DummyDriver::DummyConnection.connections.size.should eq(1) + db.pool.is_available?(DummyDriver::DummyConnection.connections.first) + end + end +end diff --git a/lib/db/spec/transaction_spec.cr b/lib/db/spec/transaction_spec.cr new file mode 100644 index 00000000..50ca671f --- /dev/null +++ b/lib/db/spec/transaction_spec.cr @@ -0,0 +1,178 @@ +require "./spec_helper" + +private class FooException < Exception +end + +describe DB::Transaction do + it "begin/commit transaction from connection" do + with_dummy_connection do |cnn| + tx = cnn.begin_transaction + tx.commit + end + end + + it "begin/rollback transaction from connection" do + with_dummy_connection do |cnn| + tx = cnn.begin_transaction + tx.rollback + end + end + + it "raise if begin over existing transaction" do + with_dummy_connection do |cnn| + cnn.begin_transaction + expect_raises(DB::Error, "There is an existing transaction in this connection") do + cnn.begin_transaction + end + end + end + + it "allow sequential transactions" do + with_dummy_connection do |cnn| + tx = cnn.begin_transaction + tx.rollback + + tx = cnn.begin_transaction + tx.commit + end + end + + it "transaction with block from connection should be committed" do + t = uninitialized DummyDriver::DummyTransaction + + with_witness do |w| + with_dummy_connection do |cnn| + cnn.transaction do |tx| + if tx.is_a?(DummyDriver::DummyTransaction) + t = tx + w.check + end + end + end + end + + t.committed.should be_true + t.rolledback.should be_false + end + + it "transaction with block from connection should be rolledback if raise DB::Rollback" do + t = uninitialized DummyDriver::DummyTransaction + + with_witness do |w| + with_dummy_connection do |cnn| + cnn.transaction do |tx| + if tx.is_a?(DummyDriver::DummyTransaction) + t = tx + w.check + end + raise DB::Rollback.new + end + end + end + + t.rolledback.should be_true + t.committed.should be_false + end + + it "transaction with block from connection should be rolledback if raise" do + t = uninitialized DummyDriver::DummyTransaction + + with_witness do |w| + with_dummy_connection do |cnn| + expect_raises(FooException) do + cnn.transaction do |tx| + if tx.is_a?(DummyDriver::DummyTransaction) + t = tx + w.check + end + raise FooException.new + end + end + end + end + + t.rolledback.should be_true + t.committed.should be_false + end + + it "transaction can be committed within block" do + with_dummy_connection do |cnn| + cnn.transaction do |tx| + tx.commit + end + end + end + + it "transaction can be rolledback within block" do + with_dummy_connection do |cnn| + cnn.transaction do |tx| + tx.rollback + end + end + end + + it "transaction can be rolledback within block and later raise" do + with_dummy_connection do |cnn| + expect_raises(FooException) do + cnn.transaction do |tx| + tx.rollback + raise FooException.new + end + end + end + end + + it "transaction can be rolledback within block and later raise DB::Rollback without forwarding it" do + with_dummy_connection do |cnn| + cnn.transaction do |tx| + tx.rollback + raise DB::Rollback.new + end + end + end + + it "transaction can't be committed twice" do + with_dummy_connection do |cnn| + cnn.transaction do |tx| + tx.commit + expect_raises(DB::Error, "Transaction already closed") do + tx.commit + end + end + end + end + + it "transaction can't be rolledback twice" do + with_dummy_connection do |cnn| + cnn.transaction do |tx| + tx.rollback + expect_raises(DB::Error, "Transaction already closed") do + tx.rollback + end + end + end + end + + it "return connection to pool after transaction block in db" do + DummyDriver::DummyConnection.clear_connections + + with_dummy do |db| + db.transaction do |tx| + db.pool.is_available?(DummyDriver::DummyConnection.connections.first).should be_false + end + db.pool.is_available?(DummyDriver::DummyConnection.connections.first).should be_true + end + end + + it "releasing result_set from within transaction should not return connection to pool" do + cnn = uninitialized DB::Connection + with_dummy do |db| + db.transaction do |tx| + cnn = tx.connection + cnn.scalar "1" + db.pool.is_available?(cnn).should be_false + end + db.pool.is_available?(cnn).should be_true + end + end +end diff --git a/lib/db/src/db.cr b/lib/db/src/db.cr new file mode 100644 index 00000000..915e7b81 --- /dev/null +++ b/lib/db/src/db.cr @@ -0,0 +1,200 @@ +require "uri" + +# The DB module is a unified interface for database access. +# Individual database systems are supported by specific database driver shards. +# +# Available drivers include: +# * [crystal-lang/crystal-sqlite3](https://github.com/crystal-lang/crystal-sqlite3) for SQLite +# * [crystal-lang/crystal-mysql](https://github.com/crystal-lang/crystal-mysql) for MySQL and MariaDB +# * [will/crystal-pg](https://github.com/will/crystal-pg) for PostgreSQL +# * [kaukas/crystal-cassandra](https://github.com/kaukas/crystal-cassandra) for Cassandra +# +# For basic instructions on implementing a new database driver, check `Driver` and the existing drivers. +# +# DB manages a connection pool. The connection pool can be configured by query parameters to the +# connection `URI` as described in `Database`. +# +# ### Usage +# +# Assuming `crystal-sqlite3` is included a SQLite3 database can be opened with `#open`. +# +# ``` +# db = DB.open "sqlite3:./path/to/db/file.db" +# db.close +# ``` +# +# If a block is given to `#open` the database is closed automatically +# +# ``` +# DB.open "sqlite3:./file.db" do |db| +# # work with db +# end # db is closed +# ``` +# +# In the code above `db` is a `Database`. Methods available for querying it are described in `QueryMethods`. +# +# Three kind of statements can be performed: +# 1. `Database#exec` waits no response from the database. +# 2. `Database#scalar` reads a single value of the response. +# 3. `Database#query` returns a ResultSet that allows iteration over the rows in the response and column information. +# +# All of the above methods allows parametrised query. Either positional or named arguments. +# +# Check a full working version: +# +# The following example uses SQLite where `?` indicates the arguments. If PostgreSQL is used `$1`, `$2`, etc. should be used. `crystal-db` does not interpret the statements. +# +# ``` +# require "db" +# require "sqlite3" +# +# DB.open "sqlite3:./file.db" do |db| +# # When using the pg driver, use $1, $2, etc. instead of ? +# db.exec "create table contacts (name text, age integer)" +# db.exec "insert into contacts values (?, ?)", "John Doe", 30 +# +# args = [] of DB::Any +# args << "Sarah" +# args << 33 +# db.exec "insert into contacts values (?, ?)", args +# +# puts "max age:" +# puts db.scalar "select max(age) from contacts" # => 33 +# +# puts "contacts:" +# db.query "select name, age from contacts order by age desc" do |rs| +# puts "#{rs.column_name(0)} (#{rs.column_name(1)})" +# # => name (age) +# rs.each do +# puts "#{rs.read(String)} (#{rs.read(Int32)})" +# # => Sarah (33) +# # => John Doe (30) +# end +# end +# end +# ``` +# +module DB + # Types supported to interface with database driver. + # These can be used in any `ResultSet#read` or any `Database#query` related + # method to be used as query parameters + TYPES = [Nil, String, Bool, Int32, Int64, Float32, Float64, Time, Bytes] + + # See `DB::TYPES` in `DB`. `Any` is a union of all types in `DB::TYPES` + {% begin %} + alias Any = Union({{*TYPES}}) + {% end %} + + # Result of a `#exec` statement. + record ExecResult, rows_affected : Int64, last_insert_id : Int64 + + # :nodoc: + def self.driver_class(driver_name) : Driver.class + drivers[driver_name]? || + raise(ArgumentError.new(%(no driver was registered for the schema "#{driver_name}", did you maybe forget to require the database driver?))) + end + + # Registers a driver class for a given *driver_name*. + # Should be called by drivers implementors only. + def self.register_driver(driver_name, driver_class : Driver.class) + drivers[driver_name] = driver_class + end + + private def self.drivers + @@drivers ||= {} of String => Driver.class + end + + # Creates a `Database` pool and opens initial connection(s) as specified in the connection *uri*. + # Use `DB#connect` to open a single connection. + # + # The scheme of the *uri* determines the driver to use. + # Connection parameters such as hostname, user, database name, etc. are specified according + # to each database driver's specific format. + # + # The returned database must be closed by `Database#close`. + def self.open(uri : URI | String) + build_database(uri) + end + + # Same as `#open` but the database is yielded and closed automatically at the end of the block. + def self.open(uri : URI | String, &block) + db = build_database(uri) + begin + yield db + ensure + db.close + end + end + + # Opens a connection using the specified *uri*. + # The scheme of the *uri* determines the driver to use. + # Returned connection must be closed by `Connection#close`. + # If a block is used the connection is yielded and closed automatically. + def self.connect(uri : URI | String) + build_connection(uri) + end + + # ditto + def self.connect(uri : URI | String, &block) + cnn = build_connection(uri) + begin + yield cnn + ensure + cnn.close + end + end + + private def self.build_database(connection_string : String) + build_database(URI.parse(connection_string)) + end + + private def self.build_database(uri : URI) + Database.new(build_driver(uri), uri) + end + + private def self.build_connection(connection_string : String) + build_connection(URI.parse(connection_string)) + end + + private def self.build_connection(uri : URI) + build_driver(uri).build_connection(SingleConnectionContext.new(uri)).as(Connection) + end + + private def self.build_driver(uri : URI) + driver_class(uri.scheme).new + end + + # :nodoc: + def self.fetch_bool(params : HTTP::Params, name, default : Bool) + case (value = params[name]?).try &.downcase + when nil + default + when "true" + true + when "false" + false + else + raise ArgumentError.new(%(invalid "#{value}" value for option "#{name}")) + end + end +end + +require "./db/pool" +require "./db/string_key_cache" +require "./db/query_methods" +require "./db/session_methods" +require "./db/disposable" +require "./db/driver" +require "./db/statement" +require "./db/begin_transaction" +require "./db/connection_context" +require "./db/connection" +require "./db/transaction" +require "./db/statement" +require "./db/pool_statement" +require "./db/database" +require "./db/pool_prepared_statement" +require "./db/pool_unprepared_statement" +require "./db/result_set" +require "./db/error" +require "./db/mapping" diff --git a/lib/db/src/db/begin_transaction.cr b/lib/db/src/db/begin_transaction.cr new file mode 100644 index 00000000..5fbe2d1c --- /dev/null +++ b/lib/db/src/db/begin_transaction.cr @@ -0,0 +1,33 @@ +module DB + module BeginTransaction + # Creates a transaction from the current context. + # If is expected that either `Transaction#commit` or `Transaction#rollback` + # are called explicitly to release the context. + abstract def begin_transaction : Transaction + + # yields a transaction from the current context. + # Query the database through `Transaction#connection` object. + # If an exception is thrown within the block a rollback is performed. + # The exception thrown is bubbled unless it is a `DB::Rollback`. + # From the yielded object `Transaction#commit` or `Transaction#rollback` + # can be called explicitly. + def transaction + tx = begin_transaction + begin + yield tx + rescue DB::Rollback + tx.rollback unless tx.closed? + rescue e + unless tx.closed? + # Ignore error in rollback. + # It would only be a secondary error to the original one, caused by + # corrupted connection state. + tx.rollback rescue nil + end + raise e + else + tx.commit unless tx.closed? + end + end + end +end diff --git a/lib/db/src/db/connection.cr b/lib/db/src/db/connection.cr new file mode 100644 index 00000000..91088254 --- /dev/null +++ b/lib/db/src/db/connection.cr @@ -0,0 +1,121 @@ +module DB + # Database driver implementors must subclass `Connection`. + # + # Represents one active connection to a database. + # + # Users should never instantiate a `Connection` manually. Use `DB#open` or `Database#connection`. + # + # Refer to `QueryMethods` for documentation about querying the database through this connection. + # + # ### Note to implementors + # + # The connection must be initialized in `#initialize` and closed in `#do_close`. + # + # Override `#build_prepared_statement` method in order to return a prepared `Statement` to allow querying. + # Override `#build_unprepared_statement` method in order to return a unprepared `Statement` to allow querying. + # See also `Statement` to define how the statements are executed. + # + # If at any give moment the connection is lost a DB::ConnectionLost should be raised. This will + # allow the connection pool to try to reconnect or use another connection if available. + # + abstract class Connection + include Disposable + include SessionMethods(Connection, Statement) + include BeginTransaction + + # :nodoc: + getter context + @statements_cache = StringKeyCache(Statement).new + @transaction = false + getter? prepared_statements : Bool + # :nodoc: + property auto_release : Bool = true + + def initialize(@context : ConnectionContext) + @prepared_statements = @context.prepared_statements? + end + + # :nodoc: + def fetch_or_build_prepared_statement(query) : Statement + @statements_cache.fetch(query) { build_prepared_statement(query) } + end + + # :nodoc: + abstract def build_prepared_statement(query) : Statement + + # :nodoc: + abstract def build_unprepared_statement(query) : Statement + + def begin_transaction : Transaction + raise DB::Error.new("There is an existing transaction in this connection") if @transaction + @transaction = true + create_transaction + end + + protected def create_transaction : Transaction + TopLevelTransaction.new(self) + end + + protected def do_close + @statements_cache.each_value &.close + @statements_cache.clear + @context.discard self + end + + # :nodoc: + protected def before_checkout + @auto_release = true + end + + # :nodoc: + protected def after_release + end + + # return this connection to the pool + # managed by the database. Should be used + # only if the connection was obtained by `Database#checkout`. + def release + @context.release(self) + end + + # :nodoc: + def release_from_statement + self.release if @auto_release && !@transaction + end + + # :nodoc: + def release_from_transaction + @transaction = false + end + + # :nodoc: + def perform_begin_transaction + self.unprepared.exec "BEGIN" + end + + # :nodoc: + def perform_commit_transaction + self.unprepared.exec "COMMIT" + end + + # :nodoc: + def perform_rollback_transaction + self.unprepared.exec "ROLLBACK" + end + + # :nodoc: + def perform_create_savepoint(name) + self.unprepared.exec "SAVEPOINT #{name}" + end + + # :nodoc: + def perform_release_savepoint(name) + self.unprepared.exec "RELEASE SAVEPOINT #{name}" + end + + # :nodoc: + def perform_rollback_savepoint(name) + self.unprepared.exec "ROLLBACK TO #{name}" + end + end +end diff --git a/lib/db/src/db/connection_context.cr b/lib/db/src/db/connection_context.cr new file mode 100644 index 00000000..31e81d83 --- /dev/null +++ b/lib/db/src/db/connection_context.cr @@ -0,0 +1,36 @@ +module DB + module ConnectionContext + # Returns the uri with the connection settings to the database + abstract def uri : URI + + # Return whether the statements should be prepared by default + abstract def prepared_statements? : Bool + + # Indicates that the *connection* was permanently closed + # and should not be used in the future. + abstract def discard(connection : Connection) + + # Indicates that the *connection* is no longer needed + # and can be reused in the future. + abstract def release(connection : Connection) + end + + # :nodoc: + class SingleConnectionContext + include ConnectionContext + + getter uri : URI + getter? prepared_statements : Bool + + def initialize(@uri : URI) + params = HTTP::Params.parse(uri.query || "") + @prepared_statements = DB.fetch_bool(params, "prepared_statements", true) + end + + def discard(connection : Connection) + end + + def release(connection : Connection) + end + end +end diff --git a/lib/db/src/db/database.cr b/lib/db/src/db/database.cr new file mode 100644 index 00000000..fd7a7092 --- /dev/null +++ b/lib/db/src/db/database.cr @@ -0,0 +1,146 @@ +require "http/params" +require "weak_ref" + +module DB + # Acts as an entry point for database access. + # Connections are managed by a pool. + # Use `DB#open` to create a `Database` instance. + # + # Refer to `QueryMethods` and `SessionMethods` for documentation about querying the database. + # + # ## Database URI + # + # Connection parameters are configured in a URI. The format is specified by the individual + # database drivers. See the [reference book](https://crystal-lang.org/reference/database/) for examples. + # + # The connection pool can be configured from URI parameters: + # + # - `initial_pool_size` (default 1) + # - `max_pool_size` (default 0 = unlimited) + # - `max_idle_pool_size` (default 1) + # - `checkout_timeout` (default 5.0) + # - `retry_attempts` (default 1) + # - `retry_delay` (in seconds, default 1.0) + # + # When querying a database, prepared statements are used by default. + # This can be changed from the `prepared_statements` URI parameter: + # + # - `prepared_statements` (true, or false, default true) + # + class Database + include SessionMethods(Database, PoolStatement) + include ConnectionContext + + # :nodoc: + getter driver + # :nodoc: + getter pool + + # Returns the uri with the connection settings to the database + getter uri : URI + + getter? prepared_statements : Bool + + @pool : Pool(Connection) + @setup_connection : Connection -> Nil + @statements_cache = StringKeyCache(PoolPreparedStatement).new + + # :nodoc: + def initialize(@driver : Driver, @uri : URI) + params = HTTP::Params.parse(uri.query || "") + @prepared_statements = DB.fetch_bool(params, "prepared_statements", true) + pool_options = @driver.connection_pool_options(params) + + @setup_connection = ->(conn : Connection) {} + @pool = uninitialized Pool(Connection) # in order to use self in the factory proc + @pool = Pool.new(**pool_options) { + conn = @driver.build_connection(self).as(Connection) + @setup_connection.call conn + conn + } + end + + def setup_connection(&proc : Connection -> Nil) + @setup_connection = proc + @pool.each_resource do |conn| + @setup_connection.call conn + end + end + + # Closes all connection to the database. + def close + @statements_cache.each_value &.close + @statements_cache.clear + + @pool.close + end + + # :nodoc: + def discard(connection : Connection) + @pool.delete connection + end + + # :nodoc: + def release(connection : Connection) + @pool.release connection + end + + # :nodoc: + def fetch_or_build_prepared_statement(query) : PoolStatement + @statements_cache.fetch(query) { build_prepared_statement(query) } + end + + # :nodoc: + def build_prepared_statement(query) : PoolStatement + PoolPreparedStatement.new(self, query) + end + + # :nodoc: + def build_unprepared_statement(query) : PoolStatement + PoolUnpreparedStatement.new(self, query) + end + + # :nodoc: + def checkout_some(candidates : Enumerable(WeakRef(Connection))) : {Connection, Bool} + @pool.checkout_some candidates + end + + # yields a connection from the pool + # the connection is returned to the pool + # when the block ends + def using_connection + connection = self.checkout + begin + yield connection + ensure + connection.release + end + end + + # returns a connection from the pool + # the returned connection must be returned + # to the pool by explictly calling `Connection#release` + def checkout + connection = @pool.checkout + connection.auto_release = false + connection + end + + # yields a `Transaction` from a connection of the pool + # Refer to `BeginTransaction#transaction` for documentation. + def transaction + using_connection do |cnn| + cnn.transaction do |tx| + yield tx + end + end + end + + # :nodoc: + def retry + @pool.retry do + yield + end + end + end +end diff --git a/lib/db/src/db/disposable.cr b/lib/db/src/db/disposable.cr new file mode 100644 index 00000000..10afb2b2 --- /dev/null +++ b/lib/db/src/db/disposable.cr @@ -0,0 +1,24 @@ +module DB + # Generic module to encapsulate disposable db resources. + module Disposable + macro included + @closed = false + end + + # Closes this object. + def close + return if @closed + do_close + @closed = true + end + + # Returns `true` if this object is closed. See `#close`. + def closed? + @closed + end + + # Implementors overrides this method to perform resource cleanup + # If an exception is raised, the resource will not be marked as closed. + protected abstract def do_close + end +end diff --git a/lib/db/src/db/driver.cr b/lib/db/src/db/driver.cr new file mode 100644 index 00000000..5df46a26 --- /dev/null +++ b/lib/db/src/db/driver.cr @@ -0,0 +1,42 @@ +module DB + # Database driver implementors must subclass `Driver`, + # register with a driver_name using `DB#register_driver` and + # override the factory method `#build_connection`. + # + # ``` + # require "db" + # + # class FakeDriver < DB::Driver + # def build_connection(context : DB::ConnectionContext) + # FakeConnection.new context + # end + # end + # + # DB.register_driver "fake", FakeDriver + # ``` + # + # Access to this fake datbase will be available with + # + # ``` + # DB.open "fake://..." do |db| + # # ... use db ... + # end + # ``` + # + # Refer to `Connection`, `Statement` and `ResultSet` for further + # driver implementation instructions. + abstract class Driver + abstract def build_connection(context : ConnectionContext) : Connection + + def connection_pool_options(params : HTTP::Params) + { + initial_pool_size: params.fetch("initial_pool_size", 1).to_i, + max_pool_size: params.fetch("max_pool_size", 0).to_i, + max_idle_pool_size: params.fetch("max_idle_pool_size", 1).to_i, + checkout_timeout: params.fetch("checkout_timeout", 5.0).to_f, + retry_attempts: params.fetch("retry_attempts", 1).to_i, + retry_delay: params.fetch("retry_delay", 1.0).to_f, + } + end + end +end diff --git a/lib/db/src/db/error.cr b/lib/db/src/db/error.cr new file mode 100644 index 00000000..c8a81f28 --- /dev/null +++ b/lib/db/src/db/error.cr @@ -0,0 +1,32 @@ +module DB + class Error < Exception + end + + class MappingException < Error + end + + class PoolTimeout < Error + end + + class PoolRetryAttemptsExceeded < Error + end + + # Raised when an established connection is lost + # probably due to socket/network issues. + # It is used by the connection pool retry logic. + class ConnectionLost < Error + getter connection : Connection + + def initialize(@connection) + end + end + + # Raised when a connection is unable to be established + # probably due to socket/network or configuration issues. + # It is used by the connection pool retry logic. + class ConnectionRefused < Error + end + + class Rollback < Error + end +end diff --git a/lib/db/src/db/mapping.cr b/lib/db/src/db/mapping.cr new file mode 100644 index 00000000..5543e501 --- /dev/null +++ b/lib/db/src/db/mapping.cr @@ -0,0 +1,154 @@ +module DB + # Empty module used for marking a class as supporting DB:Mapping + module Mappable; end + + # The `DB.mapping` macro defines how an object is built from a `ResultSet`. + # + # It takes hash literal as argument, in which attributes and types are defined. + # Once defined, `ResultSet#read(t)` populates properties of the class from the + # `ResultSet`. + # + # ```crystal + # require "db" + # + # class Employee + # DB.mapping({ + # title: String, + # name: String, + # }) + # end + # + # employees = Employee.from_rs(db.query("SELECT title, name FROM employees")) + # employees[0].title # => "Manager" + # employees[0].name # => "John" + # ``` + # + # Attributes not mapped with `DB.mapping` are not defined as properties. + # Also, missing attributes raise a `DB::MappingException`. + # + # You can also define attributes for each property. + # + # ```crystal + # class Employee + # DB.mapping({ + # title: String, + # name: { + # type: String, + # nilable: true, + # key: "firstname", + # }, + # }) + # end + # ``` + # + # Available attributes: + # + # * *type* (required) defines its type. In the example above, *title: String* is a shortcut to *title: {type: String}*. + # * *nilable* defines if a property can be a `Nil`. + # * **default**: value to use if the property is missing in the result set, or if it's `null` and `nilable` was not set to `true`. If the default value creates a new instance of an object (for example `[1, 2, 3]` or `SomeObject.new`), a different instance will be used each time a row is parsed. + # * *key* defines which column to read from a `ResultSet`. It defaults to the name of the property. + # * *converter* takes an alternate type for parsing. It requires a `#from_rs` method in that class, and returns an instance of the given type. + # + # The mapping also automatically defines Crystal properties (getters and setters) for each + # of the keys. It doesn't define a constructor accepting those arguments, but you can provide + # an overload. + # + # The macro basically defines a constructor accepting a `ResultSet` that reads from + # it and initializes this type's instance variables. + # + # This macro also declares instance variables of the types given in the mapping. + macro mapping(properties, strict = true) + include ::DB::Mappable + + {% for key, value in properties %} + {% properties[key] = {type: value} unless value.is_a?(HashLiteral) || value.is_a?(NamedTupleLiteral) %} + {% end %} + + {% for key, value in properties %} + {% value[:nilable] = true if value[:type].is_a?(Generic) && value[:type].type_vars.map(&.resolve).includes?(Nil) %} + + {% if value[:type].is_a?(Call) && value[:type].name == "|" && + (value[:type].receiver.resolve == Nil || value[:type].args.map(&.resolve).any?(&.==(Nil))) %} + {% value[:nilable] = true %} + {% end %} + {% end %} + + {% for key, value in properties %} + @{{key.id}} : {{value[:type]}} {{ (value[:nilable] ? "?" : "").id }} + + def {{key.id}}=(_{{key.id}} : {{value[:type]}} {{ (value[:nilable] ? "?" : "").id }}) + @{{key.id}} = _{{key.id}} + end + + def {{key.id}} + @{{key.id}} + end + {% end %} + + def self.from_rs(%rs : ::DB::ResultSet) + %objs = Array(self).new + %rs.each do + %objs << self.new(%rs) + end + %objs + ensure + %rs.close + end + + def initialize(%rs : ::DB::ResultSet) + {% for key, value in properties %} + %var{key.id} = nil + %found{key.id} = false + {% end %} + + %rs.each_column do |col_name| + case col_name + {% for key, value in properties %} + when {{value[:key] || key.id.stringify}} + %found{key.id} = true + %var{key.id} = + {% if value[:converter] %} + {{value[:converter]}}.from_rs(%rs) + {% elsif value[:nilable] || value[:default] != nil %} + %rs.read(::Union({{value[:type]}} | Nil)) + {% else %} + %rs.read({{value[:type]}}) + {% end %} + {% end %} + else + {% if strict %} + raise ::DB::MappingException.new("unknown result set attribute: #{col_name}") + {% else %} + %rs.read + {% end %} + end + end + + {% for key, value in properties %} + {% unless value[:nilable] || value[:default] != nil %} + if %var{key.id}.is_a?(Nil) && !%found{key.id} + raise ::DB::MappingException.new("missing result set attribute: {{(value[:key] || key).id}}") + end + {% end %} + {% end %} + + {% for key, value in properties %} + {% if value[:nilable] %} + {% if value[:default] != nil %} + @{{key.id}} = %found{key.id} ? %var{key.id} : {{value[:default]}} + {% else %} + @{{key.id}} = %var{key.id} + {% end %} + {% elsif value[:default] != nil %} + @{{key.id}} = %var{key.id}.is_a?(Nil) ? {{value[:default]}} : %var{key.id} + {% else %} + @{{key.id}} = %var{key.id}.as({{value[:type]}}) + {% end %} + {% end %} + end + end + + macro mapping(**properties) + ::DB.mapping({{properties}}) + end +end diff --git a/lib/db/src/db/pool.cr b/lib/db/src/db/pool.cr new file mode 100644 index 00000000..575c5125 --- /dev/null +++ b/lib/db/src/db/pool.cr @@ -0,0 +1,207 @@ +require "weak_ref" + +module DB + class Pool(T) + @initial_pool_size : Int32 + # maximum amount of objects in the pool. Either available or in use. + @max_pool_size : Int32 + @available = Set(T).new + @total = [] of T + @checkout_timeout : Float64 + # maximum amount of retry attempts to reconnect to the db. See `Pool#retry`. + @retry_attempts : Int32 + @retry_delay : Float64 + + def initialize(@initial_pool_size = 1, @max_pool_size = 0, @max_idle_pool_size = 1, @checkout_timeout = 5.0, + @retry_attempts = 1, @retry_delay = 0.2, &@factory : -> T) + @initial_pool_size.times { build_resource } + + @availability_channel = Channel(Nil).new + @waiting_resource = 0 + @mutex = Mutex.new + end + + # close all resources in the pool + def close : Nil + @total.each &.close + @total.clear + @available.clear + end + + def checkout : T + resource = if @available.empty? + if can_increase_pool + build_resource + else + wait_for_available + pick_available + end + else + pick_available + end + + @available.delete resource + resource.before_checkout + resource + end + + # ``` + # selected, is_candidate = pool.checkout_some(candidates) + # ``` + # `selected` be a resource from the `candidates` list and `is_candidate` == `true` + # or `selected` will be a new resource and `is_candidate` == `false` + def checkout_some(candidates : Enumerable(WeakRef(T))) : {T, Bool} + # TODO honor candidates while waiting for availables + # this will allow us to remove `candidates.includes?(resource)` + candidates.each do |ref| + resource = ref.value + if resource && is_available?(resource) + @available.delete resource + resource.before_checkout + return {resource, true} + end + end + + resource = checkout + {resource, candidates.any? { |ref| ref.value == resource }} + end + + def release(resource : T) : Nil + if can_increase_idle_pool + @available << resource + resource.after_release + @availability_channel.send nil if are_waiting_for_resource? + else + resource.close + @total.delete(resource) + end + end + + # :nodoc: + # Will retry the block if a `ConnectionLost` exception is thrown. + # It will try to reuse all of the available connection right away, + # but if a new connection is needed there is a `retry_delay` seconds delay. + def retry + current_available = @available.size + # if the pool hasn't reach the max size, allow 1 attempt + # to make a new connection if needed without sleeping + current_available += 1 if can_increase_pool + + (current_available + @retry_attempts).times do |i| + begin + sleep @retry_delay if i >= current_available + return yield + rescue e : ConnectionLost + # if the connection is lost close it to release resources + # and remove it from the known pool. + delete(e.connection) + e.connection.close + rescue e : ConnectionRefused + # a ConnectionRefused means a new connection + # was intended to be created + # nothing to due but to retry soon + end + end + raise PoolRetryAttemptsExceeded.new + end + + # :nodoc: + def each_resource + @available.each do |resource| + yield resource + end + end + + # :nodoc: + def is_available?(resource : T) + @available.includes?(resource) + end + + # :nodoc: + def delete(resource : T) + @total.delete(resource) + @available.delete(resource) + end + + private def build_resource : T + resource = @factory.call + @total << resource + @available << resource + resource + end + + private def can_increase_pool + @max_pool_size == 0 || @total.size < @max_pool_size + end + + private def can_increase_idle_pool + @available.size < @max_idle_pool_size + end + + private def pick_available + @available.first + end + + private def wait_for_available + timeout = TimeoutHelper.new(@checkout_timeout.to_f64) + inc_waiting_resource + + timeout.start + + # TODO update to select keyword for crystal 0.19 + index, _ = Channel.select(@availability_channel.receive_select_action, timeout.receive_select_action) + case index + when 0 + timeout.cancel + dec_waiting_resource + when 1 + dec_waiting_resource + raise DB::PoolTimeout.new + else + raise DB::Error.new + end + end + + private def inc_waiting_resource + @mutex.synchronize do + @waiting_resource += 1 + end + end + + private def dec_waiting_resource + @mutex.synchronize do + @waiting_resource -= 1 + end + end + + private def are_waiting_for_resource? + @mutex.synchronize do + @waiting_resource > 0 + end + end + + class TimeoutHelper + def initialize(@timeout : Float64) + @abort_timeout = false + @timeout_channel = Channel(Nil).new + end + + def receive_select_action + @timeout_channel.receive_select_action + end + + def start + spawn do + sleep @timeout + unless @abort_timeout + @timeout_channel.send nil + end + end + end + + def cancel + @abort_timeout = true + end + end + end +end diff --git a/lib/db/src/db/pool_prepared_statement.cr b/lib/db/src/db/pool_prepared_statement.cr new file mode 100644 index 00000000..b91ee1a2 --- /dev/null +++ b/lib/db/src/db/pool_prepared_statement.cr @@ -0,0 +1,56 @@ +module DB + # Represents a statement to be executed in any of the connections + # of the pool. The statement is not be executed in a prepared fashion. + # The execution of the statement is retried according to the pool configuration. + # + # See `PoolStatement` + class PoolPreparedStatement < PoolStatement + # connections where the statement was prepared + @connections = Set(WeakRef(Connection)).new + + def initialize(db : Database, query : String) + super + # Prepares a statement on some connection + # otherwise the preparation is delayed until the first execution. + # After the first initialization the connection must be released + # it will be checked out when executing it. + statement_with_retry &.release_connection + # TODO use a round-robin selection in the pool so multiple sequentially + # initialized statements are assigned to different connections. + end + + protected def do_close + # TODO close all statements on all connections. + # currently statements are closed when the connection is closed. + + # WHAT-IF the connection is busy? Should each statement be able to + # deallocate itself when the connection is free. + @connections.clear + end + + # builds a statement over a real connection + # the connection is registered in `@connections` + private def build_statement : Statement + clean_connections + conn, existing = @db.checkout_some(@connections) + begin + stmt = conn.prepared.build(@query) + rescue ex + conn.release + raise ex + end + @connections << WeakRef.new(conn) unless existing + stmt + end + + private def clean_connections + # remove disposed or closed connections + @connections.each do |ref| + conn = ref.value + if !conn || conn.closed? + @connections.delete ref + end + end + end + end +end diff --git a/lib/db/src/db/pool_statement.cr b/lib/db/src/db/pool_statement.cr new file mode 100644 index 00000000..668ce2b6 --- /dev/null +++ b/lib/db/src/db/pool_statement.cr @@ -0,0 +1,57 @@ +module DB + # When a statement is to be executed in a DB that has a connection pool + # a statement from the DB needs to be able to represent a statement in any + # of the connections of the pool. Otherwise the user will need to deal with + # actual connections in some point. + abstract class PoolStatement + include StatementMethods + + def initialize(@db : Database, @query : String) + end + + # See `QueryMethods#exec` + def exec : ExecResult + statement_with_retry &.exec + end + + # See `QueryMethods#exec` + def exec(*args) : ExecResult + statement_with_retry &.exec(*args) + end + + # See `QueryMethods#exec` + def exec(args : Array) : ExecResult + statement_with_retry &.exec(args) + end + + # See `QueryMethods#query` + def query : ResultSet + statement_with_retry &.query + end + + # See `QueryMethods#query` + def query(*args) : ResultSet + statement_with_retry &.query(*args) + end + + # See `QueryMethods#query` + def query(args : Array) : ResultSet + statement_with_retry &.query(args) + end + + # See `QueryMethods#scalar` + def scalar(*args) + statement_with_retry &.scalar(*args) + end + + # builds a statement over a real connection + # the conneciton is registered in `@connections` + private abstract def build_statement : Statement + + private def statement_with_retry + @db.retry do + return yield build_statement + end + end + end +end diff --git a/lib/db/src/db/pool_unprepared_statement.cr b/lib/db/src/db/pool_unprepared_statement.cr new file mode 100644 index 00000000..c58fafd9 --- /dev/null +++ b/lib/db/src/db/pool_unprepared_statement.cr @@ -0,0 +1,27 @@ +module DB + # Represents a statement to be executed in any of the connections + # of the pool. The statement is not be executed in a non prepared fashion. + # The execution of the statement is retried according to the pool configuration. + # + # See `PoolStatement` + class PoolUnpreparedStatement < PoolStatement + def initialize(db : Database, query : String) + super + end + + protected def do_close + # unprepared statements do not need to be release in each connection + end + + # builds a statement over a real connection + private def build_statement : Statement + conn = @db.pool.checkout + begin + conn.unprepared.build(@query) + rescue ex + conn.release + raise ex + end + end + end +end diff --git a/lib/db/src/db/query_methods.cr b/lib/db/src/db/query_methods.cr new file mode 100644 index 00000000..9676256c --- /dev/null +++ b/lib/db/src/db/query_methods.cr @@ -0,0 +1,275 @@ +module DB + # Methods to allow querying a database. + # All methods accepts a `query : String` and a set arguments. + # + # Three kind of statements can be performed: + # 1. `#exec` waits no record response from the database. An `ExecResult` is returned. + # 2. `#scalar` reads a single value of the response. A union of possible values is returned. + # 3. `#query` returns a `ResultSet` that allows iteration over the rows in the response and column information. + # + # Arguments can be passed by position + # + # ``` + # db.query("SELECT name FROM ... WHERE age > ?", age) + # ``` + # + # Convention of mapping how arguments are mapped to the query depends on each driver. + # + # Including `QueryMethods` requires a `build(query) : Statement` method that is not expected + # to be called directly. + module QueryMethods(Stmt) + # :nodoc: + abstract def build(query) : Stmt + + # Executes a *query* and returns a `ResultSet` with the results. + # The `ResultSet` must be closed manually. + # + # ``` + # result = db.query "select name from contacts where id = ?", 10 + # begin + # if result.move_next + # id = result.read(Int32) + # end + # ensure + # result.close + # end + # ``` + def query(query, *args) + build(query).query(*args) + end + + # Executes a *query* and yields a `ResultSet` with the results. + # The `ResultSet` is closed automatically. + # + # ``` + # db.query("select name from contacts where age > ?", 18) do |rs| + # rs.each do + # name = rs.read(String) + # end + # end + # ``` + def query(query, *args) + # CHECK build(query).query(*args, &block) + rs = query(query, *args) + yield rs ensure rs.close + end + + # Executes a *query* that expects a single row and yields a `ResultSet` + # positioned at that first row. + # + # The given block must not invoke `move_next` on the yielded result set. + # + # Raises `DB::Error` if there were no rows, or if there were more than one row. + # + # ``` + # name = db.query_one "select name from contacts where id = ?", 18, &.read(String) + # ``` + def query_one(query, *args, &block : ResultSet -> U) : U forall U + query(query, *args) do |rs| + raise DB::Error.new("no rows") unless rs.move_next + + value = yield rs + raise DB::Error.new("more than one row") if rs.move_next + return value + end + end + + # Executes a *query* that expects a single row and returns it + # as a tuple of the given *types*. + # + # Raises `DB::Error` if there were no rows, or if there were more than one row. + # + # ``` + # db.query_one "select name, age from contacts where id = ?", 1, as: {String, Int32} + # ``` + def query_one(query, *args, as types : Tuple) + query_one(query, *args) do |rs| + rs.read(*types) + end + end + + # Executes a *query* that expects a single row and returns it + # as a named tuple of the given *types* (the keys of the named tuple + # are not necessarily the column names). + # + # Raises `DB::Error` if there were no rows, or if there were more than one row. + # + # ``` + # db.query_one "select name, age from contacts where id = ?", 1, as: {name: String, age: Int32} + # ``` + def query_one(query, *args, as types : NamedTuple) + query_one(query, *args) do |rs| + rs.read(**types) + end + end + + # Executes a *query* that expects a single row + # and returns the first column's value as the given *type*. + # + # Raises `DB::Error` if there were no rows, or if there were more than one row. + # + # ``` + # db.query_one "select name from contacts where id = ?", 1, as: String + # ``` + def query_one(query, *args, as type : Class) + query_one(query, *args) do |rs| + rs.read(type) + end + end + + # Executes a *query* that expects at most a single row and yields a `ResultSet` + # positioned at that first row. + # + # Returns `nil`, not invoking the block, if there were no rows. + # + # Raises `DB::Error` if there were more than one row + # (this ends up invoking the block once). + # + # ``` + # name = db.query_one? "select name from contacts where id = ?", 18, &.read(String) + # typeof(name) # => String | Nil + # ``` + def query_one?(query, *args, &block : ResultSet -> U) : U? forall U + query(query, *args) do |rs| + return nil unless rs.move_next + + value = yield rs + raise DB::Error.new("more than one row") if rs.move_next + return value + end + end + + # Executes a *query* that expects a single row and returns it + # as a tuple of the given *types*. + # + # Returns `nil` if there were no rows. + # + # Raises `DB::Error` if there were more than one row. + # + # ``` + # result = db.query_one? "select name, age from contacts where id = ?", 1, as: {String, Int32} + # typeof(result) # => Tuple(String, Int32) | Nil + # ``` + def query_one?(query, *args, as types : Tuple) + query_one?(query, *args) do |rs| + rs.read(*types) + end + end + + # Executes a *query* that expects a single row and returns it + # as a named tuple of the given *types* (the keys of the named tuple + # are not necessarily the column names). + # + # Returns `nil` if there were no rows. + # + # Raises `DB::Error` if there were more than one row. + # + # ``` + # result = db.query_one? "select name, age from contacts where id = ?", 1, as: {age: String, name: Int32} + # typeof(result) # => NamedTuple(age: String, name: Int32) | Nil + # ``` + def query_one?(query, *args, as types : NamedTuple) + query_one?(query, *args) do |rs| + rs.read(**types) + end + end + + # Executes a *query* that expects a single row + # and returns the first column's value as the given *type*. + # + # Returns `nil` if there were no rows. + # + # Raises `DB::Error` if there were more than one row. + # + # ``` + # name = db.query_one? "select name from contacts where id = ?", 1, as: String + # typeof(name) # => String? + # ``` + def query_one?(query, *args, as type : Class) + query_one?(query, *args) do |rs| + rs.read(type) + end + end + + # Executes a *query* and yield a `ResultSet` positioned at the beginning + # of each row, returning an array of the values of the blocks. + # + # ``` + # names = db.query_all "select name from contacts", &.read(String) + # ``` + def query_all(query, *args, &block : ResultSet -> U) : Array(U) forall U + ary = [] of U + query_each(query, *args) do |rs| + ary.push(yield rs) + end + ary + end + + # Executes a *query* and returns an array where each row is + # read as a tuple of the given *types*. + # + # ``` + # contacts = db.query_all "select name, age from contacts", as: {String, Int32} + # ``` + def query_all(query, *args, as types : Tuple) + query_all(query, *args) do |rs| + rs.read(*types) + end + end + + # Executes a *query* and returns an array where each row is + # read as a named tuple of the given *types* (the keys of the named tuple + # are not necessarily the column names). + # + # ``` + # contacts = db.query_all "select name, age from contacts", as: {name: String, age: Int32} + # ``` + def query_all(query, *args, as types : NamedTuple) + query_all(query, *args) do |rs| + rs.read(**types) + end + end + + # Executes a *query* and returns an array where the + # value of each row is read as the given *type*. + # + # ``` + # names = db.query_all "select name from contacts", as: String + # ``` + def query_all(query, *args, as type : Class) + query_all(query, *args) do |rs| + rs.read(type) + end + end + + # Executes a *query* and yields the `ResultSet` once per each row. + # The `ResultSet` is closed automatically. + # + # ``` + # db.query_each "select name from contacts" do |rs| + # puts rs.read(String) + # end + # ``` + def query_each(query, *args) + query(query, *args) do |rs| + rs.each do + yield rs + end + end + end + + # Performs the `query` and returns an `ExecResult` + def exec(query, *args) + build(query).exec(*args) + end + + # Performs the `query` and returns a single scalar value + # + # ``` + # puts db.scalar("SELECT MAX(name)").as(String) # => (a String) + # ``` + def scalar(query, *args) + build(query).scalar(*args) + end + end +end diff --git a/lib/db/src/db/result_set.cr b/lib/db/src/db/result_set.cr new file mode 100644 index 00000000..b2bd7222 --- /dev/null +++ b/lib/db/src/db/result_set.cr @@ -0,0 +1,125 @@ +module DB + # The response of a query performed on a `Database`. + # + # See `DB` for a complete sample. + # + # Each `#read` call consumes the result and moves to the next column. + # Each column must be read in order. + # At any moment a `#move_next` can be invoked, meaning to skip the + # remaining, or even all the columns, in the current row. + # Also it is not mandatory to consume the whole `ResultSet`, hence an iteration + # through `#each` or `#move_next` can be stopped. + # + # **Note:** depending on how the `ResultSet` was obtained it might be mandatory an + # explicit call to `#close`. Check `QueryMethods#query`. + # + # ### Note to implementors + # + # 1. Override `#move_next` to move to the next row. + # 2. Override `#read` returning the next value in the row. + # 3. (Optional) Override `#read(t)` for some types `t` for which custom logic other than a simple cast is needed. + # 4. Override `#column_count`, `#column_name`. + abstract class ResultSet + include Disposable + + # :nodoc: + getter statement + + def initialize(@statement : DB::Statement) + end + + protected def do_close + statement.release_connection + end + + # TODO add_next_result_set : Bool + + # Iterates over all the rows + def each + while move_next + yield + end + end + + # Iterates over all the columns + def each_column + column_count.times do |x| + yield column_name(x) + end + end + + # Move the next row in the result. + # Return `false` if no more rows are available. + # See `#each` + abstract def move_next : Bool + + # TODO def empty? : Bool, handle internally with move_next (?) + + # Returns the number of columns in the result + abstract def column_count : Int32 + + # Returns the name of the column in `index` 0-based position. + abstract def column_name(index : Int32) : String + + # Returns the name of the columns. + def column_names + Array(String).new(column_count) { |i| column_name(i) } + end + + # Reads the next column value + abstract def read + + # Reads the next columns and maps them to a class + def read(type : DB::Mappable.class) + type.new(self) + end + + # Reads the next column value as a **type** + def read(type : T.class) : T forall T + value = read + if value.is_a?(T) + value + else + raise "#{self.class}#read returned a #{value.class}. A #{T} was expected." + end + end + + # Reads the next columns and returns a tuple of the values. + def read(*types : Class) + internal_read(*types) + end + + # Reads the next columns and returns a named tuple of the values. + def read(**types : Class) + internal_read(**types) + end + + private def internal_read(*types : *T) forall T + {% begin %} + Tuple.new( + {% for type in T %} + read({{type.instance}}), + {% end %} + ) + {% end %} + end + + private def internal_read(**types : **T) forall T + {% begin %} + NamedTuple.new( + {% for name, type in T %} + {{ name }}: read({{type.instance}}), + {% end %} + ) + {% end %} + end + + # def read_blob + # yield ... io .... + # end + + # def read_text + # yield ... io .... + # end + end +end diff --git a/lib/db/src/db/session_methods.cr b/lib/db/src/db/session_methods.cr new file mode 100644 index 00000000..1e566398 --- /dev/null +++ b/lib/db/src/db/session_methods.cr @@ -0,0 +1,73 @@ +module DB + # Methods that are shared accross session like objects: + # - Database + # - Connection + # + # Classes that includes this module are able to execute + # queries and statements in both prepared and unprepared fashion. + # + # This module serves for dsl reuse over session like objects. + module SessionMethods(Session, Stmt) + include QueryMethods(Stmt) + + # Returns whether by default the statements should + # be prepared or not. + abstract def prepared_statements? : Bool + + abstract def fetch_or_build_prepared_statement(query) : Stmt + + abstract def build_unprepared_statement(query) : Stmt + + def build(query) : Stmt + if prepared_statements? + fetch_or_build_prepared_statement(query) + else + build_unprepared_statement(query) + end + end + + # dsl helper to build prepared statements + # returns a value that includes `QueryMethods` + def prepared + PreparedQuery(Session, Stmt).new(self) + end + + # Returns a prepared `Statement` that has not been executed yet. + def prepared(query) + prepared.build(query) + end + + # dsl helper to build unprepared statements + # returns a value that includes `QueryMethods` + def unprepared + UnpreparedQuery(Session, Stmt).new(self) + end + + # Returns an unprepared `Statement` that has not been executed yet. + def unprepared(query) + unprepared.build(query) + end + + struct PreparedQuery(Session, Stmt) + include QueryMethods(Stmt) + + def initialize(@session : Session) + end + + def build(query) : Stmt + @session.fetch_or_build_prepared_statement(query) + end + end + + struct UnpreparedQuery(Session, Stmt) + include QueryMethods(Stmt) + + def initialize(@session : Session) + end + + def build(query) : Stmt + @session.build_unprepared_statement(query) + end + end + end +end diff --git a/lib/db/src/db/statement.cr b/lib/db/src/db/statement.cr new file mode 100644 index 00000000..0ab4be35 --- /dev/null +++ b/lib/db/src/db/statement.cr @@ -0,0 +1,114 @@ +module DB + # Common interface for connection based statements + # and for connection pool statements. + module StatementMethods + include Disposable + + protected def do_close + end + + # See `QueryMethods#scalar` + def scalar(*args) + query(*args) do |rs| + rs.each do + return rs.read + end + end + + raise "no results" + end + + # See `QueryMethods#query` + def query(*args) + rs = query(*args) + yield rs ensure rs.close + end + + # See `QueryMethods#exec` + abstract def exec : ExecResult + # See `QueryMethods#exec` + abstract def exec(*args) : ExecResult + # See `QueryMethods#exec` + abstract def exec(args : Array) : ExecResult + + # See `QueryMethods#query` + abstract def query : ResultSet + # See `QueryMethods#query` + abstract def query(*args) : ResultSet + # See `QueryMethods#query` + abstract def query(args : Array) : ResultSet + end + + # Represents a query in a `Connection`. + # It should be created by `QueryMethods`. + # + # ### Note to implementors + # + # 1. Subclass `Statements` + # 2. `Statements` are created from a custom driver `Connection#prepare` method. + # 3. `#perform_query` executes a query that is expected to return a `ResultSet` + # 4. `#perform_exec` executes a query that is expected to return an `ExecResult` + # 6. `#do_close` is called to release the statement resources. + abstract class Statement + include StatementMethods + + # :nodoc: + getter connection + + def initialize(@connection : Connection) + end + + def release_connection + @connection.release_from_statement + end + + # See `QueryMethods#exec` + def exec : DB::ExecResult + perform_exec_and_release(Slice(Any).empty) + end + + # See `QueryMethods#exec` + def exec(args : Array) : DB::ExecResult + perform_exec_and_release(args) + end + + # See `QueryMethods#exec` + def exec(*args) + # TODO better way to do it + perform_exec_and_release(args) + end + + # See `QueryMethods#query` + def query : DB::ResultSet + perform_query_with_rescue Tuple.new + end + + # See `QueryMethods#query` + def query(args : Array) : DB::ResultSet + perform_query_with_rescue args + end + + # See `QueryMethods#query` + def query(*args) + perform_query_with_rescue args + end + + private def perform_exec_and_release(args : Enumerable) : ExecResult + return perform_exec(args) + ensure + release_connection + end + + private def perform_query_with_rescue(args : Enumerable) : ResultSet + return perform_query(args) + rescue e : Exception + # Release connection only when an exception occurs during the query + # execution since we need the connection open while the ResultSet is open + release_connection + raise e + end + + protected abstract def perform_query(args : Enumerable) : ResultSet + protected abstract def perform_exec(args : Enumerable) : ExecResult + end +end diff --git a/lib/db/src/db/string_key_cache.cr b/lib/db/src/db/string_key_cache.cr new file mode 100644 index 00000000..f2cae629 --- /dev/null +++ b/lib/db/src/db/string_key_cache.cr @@ -0,0 +1,21 @@ +module DB + class StringKeyCache(T) + @cache = {} of String => T + + def fetch(key : String) : T + value = @cache.fetch(key, nil) + value = @cache[key] = yield unless value + value + end + + def each_value + @cache.each do |_, value| + yield value + end + end + + def clear + @cache.clear + end + end +end diff --git a/lib/db/src/db/transaction.cr b/lib/db/src/db/transaction.cr new file mode 100644 index 00000000..75837af5 --- /dev/null +++ b/lib/db/src/db/transaction.cr @@ -0,0 +1,131 @@ +module DB + # Transactions should be started from `DB#transaction`, `Connection#transaction` + # or `Connection#begin_transaction`. + # + # Use `Transaction#connection` to submit statements to the database. + # + # Use `Transaction#commit` or `Transaction#rollback` to close the ongoing transaction + # explicitly. Or refer to `BeginTransaction#transaction` for documentation on how to + # use `#transaction(&block)` methods in `DB` and `Connection`. + # + # Nested transactions are supported by using sql `SAVEPOINT`. To start a nested + # transaction use `Transaction#transaction` or `Transaction#begin_transaction`. + # + abstract class Transaction + include Disposable + include BeginTransaction + + abstract def connection : Connection + + # commits the current transaction + def commit + close! + end + + # rollbacks the current transaction + def rollback + close! + end + + private def close! + raise DB::Error.new("Transaction already closed") if closed? + close + end + + abstract def release_from_nested_transaction + end + + class TopLevelTransaction < Transaction + getter connection : Connection + # :nodoc: + property savepoint_name : String? = nil + + def initialize(@connection : Connection) + @nested_transaction = false + @connection.perform_begin_transaction + end + + def commit + @connection.perform_commit_transaction + super + end + + def rollback + @connection.perform_rollback_transaction + super + end + + protected def do_close + connection.release_from_transaction + end + + def begin_transaction : Transaction + raise DB::Error.new("There is an existing nested transaction in this transaction") if @nested_transaction + @nested_transaction = true + create_save_point_transaction(self) + end + + # :nodoc: + def create_save_point_transaction(parent : Transaction) : SavePointTransaction + # TODO should we wrap this in a mutex? + previous_savepoint = @savepoint_name + savepoint_name = if previous_savepoint + previous_savepoint.succ + else + # random prefix to avoid determinism + "cr_#{@connection.object_id}_#{Random.rand(10_000)}_00001" + end + + @savepoint_name = savepoint_name + + create_save_point_transaction(parent, savepoint_name) + end + + protected def create_save_point_transaction(parent : Transaction, savepoint_name : String) : SavePointTransaction + SavePointTransaction.new(parent, savepoint_name) + end + + # :nodoc: + def release_from_nested_transaction + @nested_transaction = false + end + end + + class SavePointTransaction < Transaction + getter connection : Connection + + def initialize(@parent : Transaction, @savepoint_name : String) + @nested_transaction = false + @connection = @parent.connection + @connection.perform_create_savepoint(@savepoint_name) + end + + def commit + @connection.perform_release_savepoint(@savepoint_name) + super + end + + def rollback + @connection.perform_rollback_savepoint(@savepoint_name) + super + end + + protected def do_close + @parent.release_from_nested_transaction + end + + def begin_transaction : Transaction + raise DB::Error.new("There is an existing nested transaction in this transaction") if @nested_transaction + @nested_transaction = true + create_save_point_transaction(self) + end + + def create_save_point_transaction(parent : Transaction) + @parent.create_save_point_transaction(parent) + end + + def release_from_nested_transaction + @nested_transaction = false + end + end +end diff --git a/lib/db/src/db/version.cr b/lib/db/src/db/version.cr new file mode 100644 index 00000000..6660da24 --- /dev/null +++ b/lib/db/src/db/version.cr @@ -0,0 +1,3 @@ +module DB + VERSION = "0.6.0" +end diff --git a/lib/db/src/spec.cr b/lib/db/src/spec.cr new file mode 100644 index 00000000..05990586 --- /dev/null +++ b/lib/db/src/spec.cr @@ -0,0 +1,514 @@ +require "spec" + +private def assert_single_read(rs, value_type, value) + rs.move_next.should be_true + rs.read(value_type).should eq(value) + rs.move_next.should be_false +end + +module DB + # Helper class to ensure behaviour of custom drivers + # + # ``` + # require "db/spec" + # + # DB::DriverSpecs(DB::Any).run do + # # How to connect to database + # connection_string "scheme://database_url" + # + # # Clean up database if needed using before/after callbacks + # before do + # # ... + # end + # + # after do + # # ... + # end + # + # # Sample values that will be stored, retrieved across many specs + # sample_value "hello", "varchar(25)", "'hello'" + # + # it "custom spec with a db initialized" do |db| + # # assert something using *db* + # end + # + # # Configure the appropiate syntax for different commands needed to run the specs + # binding_syntax do |index| + # "?" + # end + # + # create_table_1column_syntax do |table_name, col1| + # "create table #{table_name} (#{col1.name} #{col1.sql_type} #{col1.null ? "NULL" : "NOT NULL"})" + # end + # end + # ``` + # + # The following methods needs to be called to configure the appropiate syntax + # for different commands and allow all the specs to run: `binding_syntax`, `create_table_1column_syntax`, + # `create_table_2columns_syntax`, `select_1column_syntax`, `select_2columns_syntax`, `select_count_syntax`, + # `select_scalar_syntax`, `insert_1column_syntax`, `insert_2columns_syntax`, `drop_table_if_exists_syntax`. + # + class DriverSpecs(DBAnyType) + record ColumnDef, name : String, sql_type : String, null : Bool + + @before : Proc(Nil) = ->{} + @after : Proc(Nil) = ->{} + @encode_null = "NULL" + @support_prepared = true + @support_unprepared = true + + def before(&@before : -> Nil) + end + + def after(&@after : -> Nil) + end + + def encode_null(@encode_null : String) + end + + # Allow specs that uses prepared statements (default `true`) + def support_prepared(@support_prepared : Bool) + end + + # :nodoc: + def support_prepared + @support_prepared + end + + # Allow specs that uses unprepared statements (default `true`) + def support_unprepared(@support_unprepared : Bool) + end + + # :nodoc: + def support_unprepared + @support_unprepared + end + + # :nodoc: + macro db_spec_config(name, *, block = false) + {% if name.is_a?(TypeDeclaration) %} + @{{name.var.id}} : {{name.type}}? + + {% if block %} + def {{name.var.id}}(&@{{name.var.id}} : {{name.type}}) + end + {% else %} + def {{name.var.id}}(@{{name.var.id}} : {{name.type}}) + end + {% end %} + + # :nodoc: + def {{name.var.id}} + res = @{{name.var.id}} + raise "Missing {{name.var.id}} to setup db" unless res + res + end + {% end %} + end + + db_spec_config connection_string : String + db_spec_config binding_syntax : Proc(Int32, String), block: true + db_spec_config select_scalar_syntax : Proc(String, String?, String), block: true + db_spec_config create_table_1column_syntax : Proc(String, ColumnDef, String), block: true + db_spec_config create_table_2columns_syntax : Proc(String, ColumnDef, ColumnDef, String), block: true + db_spec_config insert_1column_syntax : Proc(String, ColumnDef, String, String), block: true + db_spec_config insert_2columns_syntax : Proc(String, ColumnDef, String, ColumnDef, String, String), block: true + db_spec_config select_1column_syntax : Proc(String, ColumnDef, String), block: true + db_spec_config select_2columns_syntax : Proc(String, ColumnDef, ColumnDef, String), block: true + db_spec_config select_count_syntax : Proc(String, String), block: true + db_spec_config drop_table_if_exists_syntax : Proc(String, String), block: true + + # :nodoc: + record SpecIt, description : String, prepared : Symbol, file : String, line : Int32, end_line : Int32, block : DB::Database -> Nil + getter its = [] of SpecIt + + def it(description = "assert", prepared = :default, file = __FILE__, line = __LINE__, end_line = __END_LINE__, &block : DB::Database ->) + return unless Spec.matches?(description, file, line, end_line) + @its << SpecIt.new(description, prepared, file, line, end_line, block) + end + + # :nodoc: + record ValueDef(T), value : T, sql_type : String, value_encoded : String + + @values = [] of ValueDef(DBAnyType) + + # Use *value* as sample value that should be stored in columns of type *sql_type*. + # *value_encoded* is driver specific expression that should generate that value in the database. + # *type_safe_value* indicates whether *value_encoded* is expected to generate the *value* even without + # been stored in a table (default `true`). + def sample_value(value, sql_type, value_encoded, *, type_safe_value = true) + @values << ValueDef(DBAnyType).new(value, sql_type, value_encoded) + + it "select nil as (#{typeof(value)} | Nil)", prepared: :both do |db| + db.query select_scalar(encode_null, nil) do |rs| + assert_single_read rs, typeof(value || nil), nil + end + end + + value_desc = value.to_s + value_desc = "#{value_desc[0..25]}...(#{value_desc.size})" if value_desc.size > 25 + value_desc = "#{value_desc} as #{sql_type}" + + if type_safe_value + it "executes with bind #{value_desc}" do |db| + db.scalar(select_scalar(param(1), sql_type), value).should eq(value) + end + + it "executes with bind #{value_desc} as array" do |db| + db.scalar(select_scalar(param(1), sql_type), [value]).should eq(value) + end + + it "select #{value_desc} as literal" do |db| + db.scalar(select_scalar(value_encoded, sql_type)).should eq(value) + + db.query select_scalar(value_encoded, sql_type) do |rs| + assert_single_read rs, typeof(value), value + end + end + end + + it "insert/get value #{value_desc} from table", prepared: :both do |db| + db.exec sql_create_table_table1(c1 = col1(sql_type)) + db.exec sql_insert_table1(c1, value_encoded) + + db.query_one(sql_select_table1(c1), as: typeof(value)).should eq(value) + end + + it "insert/get value #{value_desc} from table as nillable", prepared: :both do |db| + db.exec sql_create_table_table1(c1 = col1(sql_type)) + db.exec sql_insert_table1(c1, value_encoded) + + db.query_one(sql_select_table1(c1), as: ::Union(typeof(value) | Nil)).should eq(value) + end + + it "insert/get value nil from table as nillable #{sql_type}", prepared: :both do |db| + db.exec sql_create_table_table1(c1 = col1(sql_type, null: true)) + db.exec sql_insert_table1(c1, encode_null) + + db.query_one(sql_select_table1(c1), as: ::Union(typeof(value) | Nil)).should eq(nil) + end + + it "insert/get value #{value_desc} from table with binding" do |db| + db.exec sql_create_table_table2(c1 = col1(sql_type_for(String)), c2 = col2(sql_type)) + # the next statement will force a union in the *args + db.exec sql_insert_table2(c1, param(1), c2, param(2)), value_for(String), value + db.query_one(sql_select_table2(c2), as: typeof(value)).should eq(value) + end + + it "insert/get value #{value_desc} from table as nillable with binding" do |db| + db.exec sql_create_table_table2(c1 = col1(sql_type_for(String)), c2 = col2(sql_type)) + # the next statement will force a union in the *args + db.exec sql_insert_table2(c1, param(1), c2, param(2)), value_for(String), value + db.query_one(sql_select_table2(c2), as: ::Union(typeof(value) | Nil)).should eq(value) + end + + it "insert/get value nil from table as nillable #{sql_type} with binding" do |db| + db.exec sql_create_table_table2(c1 = col1(sql_type_for(String)), c2 = col2(sql_type, null: true)) + db.exec sql_insert_table2(c1, param(1), c2, param(2)), value_for(String), nil + + db.query_one(sql_select_table2(c2), as: ::Union(typeof(value) | Nil)).should eq(nil) + end + + it "can use read(#{typeof(value)}) with DB::ResultSet", prepared: :both do |db| + db.exec sql_create_table_table1(c1 = col1(sql_type)) + db.exec sql_insert_table1(c1, value_encoded) + db.query(sql_select_table1(c1)) do |rs| + assert_single_read rs.as(DB::ResultSet), typeof(value), value + end + end + + it "can use read(#{typeof(value)}?) with DB::ResultSet", prepared: :both do |db| + db.exec sql_create_table_table1(c1 = col1(sql_type)) + db.exec sql_insert_table1(c1, value_encoded) + db.query(sql_select_table1(c1)) do |rs| + assert_single_read rs.as(DB::ResultSet), ::Union(typeof(value) | Nil), value + end + end + + it "can use read(#{typeof(value)}?) with DB::ResultSet for nil", prepared: :both do |db| + db.exec sql_create_table_table1(c1 = col1(sql_type, null: true)) + db.exec sql_insert_table1(c1, encode_null) + db.query(sql_select_table1(c1)) do |rs| + assert_single_read rs.as(DB::ResultSet), ::Union(typeof(value) | Nil), nil + end + end + end + + # :nodoc: + def include_shared_specs + it "connects using connection_string" do |db| + db.is_a?(DB::Database) + end + + it "can create direct connection" do + DB.connect(connection_string) do |cnn| + cnn.is_a?(DB::Connection) + cnn.scalar(select_scalar(encode_null, nil)).should be_nil + end + end + + it "binds nil" do |db| + # PG is unable to perform this query without a type annotation + db.scalar(select_scalar(param(1), sql_type_for(String)), nil).should be_nil + end + + it "selects nil as scalar", prepared: :both do |db| + db.scalar(select_scalar(encode_null, nil)).should be_nil + end + + it "gets column count", prepared: :both do |db| + db.exec sql_create_table_person + db.query "select * from person" do |rs| + rs.column_count.should eq(2) + end + end + + it "gets column name", prepared: :both do |db| + db.exec sql_create_table_person + + db.query "select name, age from person" do |rs| + rs.column_name(0).should eq("name") + rs.column_name(1).should eq("age") + end + end + + it "gets many rows from table" do |db| + db.exec sql_create_table_person + db.exec sql_insert_person, "foo", 10 + db.exec sql_insert_person, "bar", 20 + db.exec sql_insert_person, "baz", 30 + + names = [] of String + ages = [] of Int32 + db.query sql_select_person do |rs| + rs.each do + names << rs.read(String) + ages << rs.read(Int32) + end + end + names.should eq(["foo", "bar", "baz"]) + ages.should eq([10, 20, 30]) + end + + # describe "transactions" do + it "transactions: can read inside transaction and rollback after" do |db| + db.exec sql_create_table_person + db.transaction do |tx| + tx.connection.scalar(sql_select_count_person).should eq(0) + tx.connection.exec sql_insert_person, "John Doe", 10 + tx.connection.scalar(sql_select_count_person).should eq(1) + tx.rollback + end + db.scalar(sql_select_count_person).should eq(0) + end + + it "transactions: can read inside transaction or after commit" do |db| + db.exec sql_create_table_person + db.transaction do |tx| + tx.connection.scalar(sql_select_count_person).should eq(0) + tx.connection.exec sql_insert_person, "John Doe", 10 + tx.connection.scalar(sql_select_count_person).should eq(1) + # using other connection + db.scalar(sql_select_count_person).should eq(0) + end + db.scalar("select count(*) from person").should eq(1) + end + # end + + # describe "nested transactions" do + it "nested transactions: can read inside transaction and rollback after" do |db| + db.exec sql_create_table_person + db.transaction do |tx_0| + tx_0.connection.scalar(sql_select_count_person).should eq(0) + tx_0.connection.exec sql_insert_person, "John Doe", 10 + tx_0.transaction do |tx_1| + tx_1.connection.exec sql_insert_person, "Sarah", 11 + tx_1.connection.scalar(sql_select_count_person).should eq(2) + tx_1.transaction do |tx_2| + tx_2.connection.exec sql_insert_person, "Jimmy", 12 + tx_2.connection.scalar(sql_select_count_person).should eq(3) + tx_2.rollback + end + end + tx_0.connection.scalar(sql_select_count_person).should eq(2) + tx_0.rollback + end + db.scalar(sql_select_count_person).should eq(0) + end + # end + end + + # :nodoc: + def with_db(options = nil) + @before.call + DB.open("#{connection_string}#{"?#{options}" if options}") do |db| + db.exec(sql_drop_table("table1")) + db.exec(sql_drop_table("table2")) + db.exec(sql_drop_table("person")) + yield db + end + ensure + @after.call + end + + # :nodoc: + def select_scalar(expression, sql_type) + select_scalar_syntax.call(expression, sql_type) + end + + # :nodoc: + def param(index) + binding_syntax.call(index) + end + + # :nodoc: + def encode_null + @encode_null + end + + # :nodoc: + def sql_type_for(a_class) + value = @values.select { |v| v.value.class == a_class }.first? + if value + value.sql_type + else + raise "missing sample_value with #{a_class}" + end + end + + # :nodoc: + macro value_for(a_class) + _value_for({{a_class}}).as({{a_class}}) + end + + # :nodoc: + def _value_for(a_class) + value = @values.select { |v| v.value.class == a_class }.first? + if value + value.value + else + raise "missing sample_value with #{a_class}" + end + end + + # :nodoc: + def col_name + ColumnDef.new("name", sql_type_for(String), false) + end + + # :nodoc: + def col_age + ColumnDef.new("age", sql_type_for(Int32), false) + end + + # :nodoc: + def sql_create_table_person + create_table_2columns_syntax.call("person", col_name, col_age) + end + + # :nodoc: + def sql_select_person + select_2columns_syntax.call("person", col_name, col_age) + end + + # :nodoc: + def sql_insert_person + insert_2columns_syntax.call("person", col_name, param(1), col_age, param(2)) + end + + # :nodoc: + def sql_select_count_person + select_count_syntax.call("person") + end + + # :nodoc: + def col1(sql_type, *, null = false) + ColumnDef.new("col1", sql_type, null) + end + + # :nodoc: + def col2(sql_type, *, null = false) + ColumnDef.new("col2", sql_type, null) + end + + # :nodoc: + def sql_create_table_table1(col : ColumnDef) + create_table_1column_syntax.call("table1", col) + end + + # :nodoc: + def sql_create_table_table2(col1 : ColumnDef, col2 : ColumnDef) + create_table_2columns_syntax.call("table2", col1, col2) + end + + # :nodoc: + def sql_insert_table1(col1 : ColumnDef, expression) + insert_1column_syntax.call("table1", col1, expression) + end + + # :nodoc: + def sql_insert_table2(col1 : ColumnDef, expr1, col2 : ColumnDef, expr2) + insert_2columns_syntax.call("table2", col1, expr1, col2, expr2) + end + + # :nodoc: + def sql_select_table1(col : ColumnDef) + select_1column_syntax.call("table1", col) + end + + # :nodoc: + def sql_select_table2(col : ColumnDef) + select_1column_syntax.call("table2", col) + end + + # :nodoc: + def sql_drop_table(table_name) + drop_table_if_exists_syntax.call(table_name) + end + + def self.run(description = "as a db") + ctx = self.new + with ctx yield + + describe description do + ctx.include_shared_specs + + ctx.its.each do |db_it| + case db_it.prepared + when :default + it(db_it.description, db_it.file, db_it.line, db_it.end_line) do + ctx.with_db do |db| + db_it.block.call db + nil + end + end + when :both + values = [] of Bool + values << true if ctx.support_prepared + values << false if ctx.support_unprepared + case values.size + when 0 + raise "Neither prepared non unprepared statements allowed" + when 1 + it(db_it.description, db_it.file, db_it.line, db_it.end_line) do + ctx.with_db do |db| + db_it.block.call db + nil + end + end + else + values.each do |prepared_statements| + it("#{db_it.description} (prepared_statements=#{prepared_statements})", db_it.file, db_it.line, db_it.end_line) do + ctx.with_db "prepared_statements=#{prepared_statements}" do |db| + db_it.block.call db + nil + end + end + end + end + end + end + end + end + end +end diff --git a/lib/exception_page/.editorconfig b/lib/exception_page/.editorconfig new file mode 100644 index 00000000..163eb75c --- /dev/null +++ b/lib/exception_page/.editorconfig @@ -0,0 +1,9 @@ +root = true + +[*.cr] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +indent_style = space +indent_size = 2 +trim_trailing_whitespace = true diff --git a/lib/exception_page/.gitignore b/lib/exception_page/.gitignore new file mode 100644 index 00000000..e29dae78 --- /dev/null +++ b/lib/exception_page/.gitignore @@ -0,0 +1,9 @@ +/docs/ +/lib/ +/bin/ +/.shards/ +*.dwarf + +# Libraries don't need dependency lock +# Dependencies will be locked in application that uses them +/shard.lock diff --git a/lib/exception_page/.travis.yml b/lib/exception_page/.travis.yml new file mode 100644 index 00000000..1b186392 --- /dev/null +++ b/lib/exception_page/.travis.yml @@ -0,0 +1,13 @@ +language: crystal +addons: + chrome: stable +before_install: + # Setup chromedriver for LuckyFlow + - sudo apt-get install chromium-chromedriver + - sudo ln -s /usr/lib/chromium-browser/chromedriver /usr/bin/chromedriver + - "export DISPLAY=:99.0" + - "sh -e /etc/init.d/xvfb start" + - sleep 3 # give xvfb some time to start +script: + - crystal spec + - crystal tool format spec src --check diff --git a/lib/exception_page/LICENSE b/lib/exception_page/LICENSE new file mode 100644 index 00000000..41e89166 --- /dev/null +++ b/lib/exception_page/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2018 Paul Smith + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/lib/exception_page/README.md b/lib/exception_page/README.md new file mode 100644 index 00000000..e069ea37 --- /dev/null +++ b/lib/exception_page/README.md @@ -0,0 +1,111 @@ +# Exception Page + +A library for displaying exceptional exception pages for easier debugging. + +![screen shot 2018-06-29 at 2 39 18 pm](https://user-images.githubusercontent.com/22394/42109073-6e767d06-7baa-11e8-9ec9-0a2afce605be.png) + +## Installation + +Add this to your application's `shard.yml`: + +```yaml +dependencies: + exception_page: + github: crystal-loot/exception_page +``` + +## Usage + +Require the shard: + +```crystal +require "exception_page" +``` + +Create an exception page: + +```crystal +class MyApp::ExceptionPage < ExceptionPage + def styles + ExceptionPage::Styles.new( + accent: "purple", # Choose the HTML color value. Can be hex + ) + end +end +``` + +Render the HTML when an exception occurs: + +```crystal +class MyErrorHandler + include HTTP::Handler + + def call_next(context) + begin + # Normally you'd call some code to handle the request + # We're hard-coding an error here to show you how to use the lib. + raise SomeError.new("Something went wrong") + rescue e + context.response.status_code = 500 + context.response.print MyApp::ExceptionPage.for_runtime_exception(context, e).to_s + end + end +``` + +## Customizing the page + +```crystal +class MyApp::ExceptionPage < ExceptionPage + def styles + ExceptionPage::Styles.new( + accent: "purple", # Required + highlight: "gray", # Optional + flash_highlight: "red", # Optional + logo_uri: "base64_encoded_data_uri" # Optional. Defaults to Crystal logo. Generate a logo here: https://dopiaza.org/tools/datauri/index.php + ) + end + + # Optional. If provided, clicking the logo will open this page + def project_url + "https://myproject.com" + end + + # Optional + def stack_trace_heading_html + <<-HTML + Say hi + HTML + end + + # Optional + def extra_javascript + <<-JAVASCRIPT + window.sayHi = function() { + alert("Say Hi!"); + } + JAVASCRIPT + end +end +``` + +## Development + +TODO: Write development instructions here + +## Contributing + +1. Fork it () +2. Create your feature branch (`git checkout -b my-new-feature`) +3. Commit your changes (`git commit -am 'Add some feature'`) +4. Push to the branch (`git push origin my-new-feature`) +5. Create a new Pull Request + +## Contributors + +- [@paulcsmith](https://github.com/paulcsmith) Paul Smith +- [@faustinoaq](https://github.com/faustinoaq) Faustino Aigular - Wrote the initial [Amber PR adding exception pages](https://github.com/amberframework/amber/pull/864) + +## Special Thanks + +This exception page is heavily based on the [Phoenix error page](https://github.com/phoenixframework/phoenix/issues/1776) +by [@rstacruz](https://github.com/rstacruz). Thanks to the Phoenix team and @rstacruz! diff --git a/lib/exception_page/shard.yml b/lib/exception_page/shard.yml new file mode 100644 index 00000000..a69315a1 --- /dev/null +++ b/lib/exception_page/shard.yml @@ -0,0 +1,15 @@ +name: exception_page +version: 0.1.2 + +authors: + - Paul Smith + - Faustino Aguilar + +development_dependencies: + lucky_flow: + github: luckyframework/lucky_flow + version: ~> 0.2.0 + +crystal: 0.25.0 + +license: MIT diff --git a/lib/exception_page/spec/exception_page_spec.cr b/lib/exception_page/spec/exception_page_spec.cr new file mode 100644 index 00000000..88532ab1 --- /dev/null +++ b/lib/exception_page/spec/exception_page_spec.cr @@ -0,0 +1,33 @@ +require "./spec_helper" + +describe ExceptionPage do + it "allows debugging the exception page" do + flow = ErrorDebuggingFlow.new + + flow.view_error_page + flow.should_have_information_for_debugging + flow.show_all_frames + flow.should_be_able_to_view_other_frames + end +end + +class ErrorDebuggingFlow < LuckyFlow + def view_error_page + visit "/" + end + + def should_have_information_for_debugging + el("@exception-title", text: "Something went very wrong").should be_on_page + el("@code-frames", text: "test_server.cr").should be_on_page + el("@code-preview").should be_on_page + end + + def show_all_frames + el("@show-all-frames").click + end + + def should_be_able_to_view_other_frames + el("@code-frame-file", "request_processor.cr").click + el("@code-frame-summary", text: "request_processor.cr").should be_on_page + end +end diff --git a/lib/exception_page/spec/frame_spec.cr b/lib/exception_page/spec/frame_spec.cr new file mode 100644 index 00000000..5a47be00 --- /dev/null +++ b/lib/exception_page/spec/frame_spec.cr @@ -0,0 +1,27 @@ +require "./spec_helper" + +describe "Frame parsing" do + it "returns the correct label" do + frame = frame_for("from usr/crystal-lang/frame_spec.cr:6:7 in '->'") + frame.label.should eq("crystal") + + frame = frame_for("from usr/crystal/frame_spec.cr:6:7 in '->'") + frame.label.should eq("crystal") + + frame = frame_for("from lib/exception_page/spec/frame_spec.cr:6:7 in '->'") + frame.label.should eq("exception_page") + + frame = frame_for("from lib/exception_page/frame_spec.cr:6:7 in '->'") + frame.label.should eq("exception_page") + + frame = frame_for("from lib/frame_spec.cr:6:7 in '->'") + frame.label.should eq("app") + + frame = frame_for("from src/frame_spec.cr:6:7 in '->'") + frame.label.should eq("app") + end +end + +private def frame_for(backtrace_line) + ExceptionPage::FrameGenerator.generate_frames(backtrace_line).first +end diff --git a/lib/exception_page/spec/spec_helper.cr b/lib/exception_page/spec/spec_helper.cr new file mode 100644 index 00000000..aec99871 --- /dev/null +++ b/lib/exception_page/spec/spec_helper.cr @@ -0,0 +1,25 @@ +require "spec" +require "lucky_flow" +require "http" +require "../src/exception_page" +require "./support/**" + +include LuckyFlow::Expectations + +server = TestServer.new(3002) + +LuckyFlow.configure do |settings| + settings.base_uri = "http://localhost:3002" + settings.stop_retrying_after = 40.milliseconds +end + +spawn do + server.listen +end + +at_exit do + LuckyFlow.shutdown + server.close +end + +Habitat.raise_if_missing_settings! diff --git a/lib/exception_page/spec/support/app_exception_page.cr b/lib/exception_page/spec/support/app_exception_page.cr new file mode 100644 index 00000000..3c8215c7 --- /dev/null +++ b/lib/exception_page/spec/support/app_exception_page.cr @@ -0,0 +1,19 @@ +class MyApp::ExceptionPage < ExceptionPage + def styles + Styles.new(accent: "purple") + end + + def stack_trace_heading_html + <<-HTML + Say hi + HTML + end + + def extra_javascript + <<-JAVASCRIPT + window.sayHi = function() { + alert("Say Hi!"); + } + JAVASCRIPT + end +end diff --git a/lib/exception_page/spec/support/test_server.cr b/lib/exception_page/spec/support/test_server.cr new file mode 100644 index 00000000..cdf3dba4 --- /dev/null +++ b/lib/exception_page/spec/support/test_server.cr @@ -0,0 +1,22 @@ +class TestServer + delegate listen, close, to: @server + + def initialize(port : Int32) + @server = HTTP::Server.new do |context| + if context.request.resource == "/favicon.ico" + context.response.print "" + else + begin + raise CustomException.new("Something went very wrong") + rescue e : CustomException + context.response.content_type = "text/html" + context.response.print MyApp::ExceptionPage.for_runtime_exception(context, e).to_s + end + end + end + @server.bind_tcp port: port + end +end + +class CustomException < Exception +end diff --git a/lib/exception_page/src/exception_page.cr b/lib/exception_page/src/exception_page.cr new file mode 100644 index 00000000..6bbf5039 --- /dev/null +++ b/lib/exception_page/src/exception_page.cr @@ -0,0 +1,54 @@ +abstract class ExceptionPage +end + +require "ecr" +require "./exception_page/*" + +# :nodoc: +abstract class ExceptionPage + @params : Hash(String, String) + @headers : Hash(String, Array(String)) + @session : Hash(String, HTTP::Cookie) + @method : String + @path : String + @message : String + @query : String + @frames = [] of Frame + @title : String + + abstract def styles : Styles + + # Add an optional link to your project + def project_url : String? + nil + end + + # Override this method to add extra HTML to the top of the stack trace heading + def stack_trace_heading_html + "" + end + + # Override this method to add extra javascript to the page + def extra_javascript + "" + end + + # :nodoc: + def initialize(context : HTTP::Server::Context, @message, @title, @frames) + @params = context.request.query_params.to_h + @headers = context.response.headers.to_h + @method = context.request.method + @path = context.request.path + @url = "#{context.request.host_with_port}#{context.request.path}" + @query = context.request.query_params.to_s + @session = context.response.cookies.to_h + end + + def self.for_runtime_exception(context : HTTP::Server::Context, ex : Exception) + title = "Error #{context.response.status_code}" + frames = FrameGenerator.generate_frames(ex.inspect_with_backtrace) + new(context, ex.message.to_s, title: title, frames: frames) + end + + ECR.def_to_s "#{__DIR__}/exception_page/exception_page.ecr" +end diff --git a/lib/exception_page/src/exception_page/exception_page.ecr b/lib/exception_page/src/exception_page/exception_page.ecr new file mode 100644 index 00000000..80ebee37 --- /dev/null +++ b/lib/exception_page/src/exception_page/exception_page.ecr @@ -0,0 +1,855 @@ + +<%- +monospace_font = "menlo, consolas, monospace" +-%> + + + + <% + details = @message.split('\n') + headline = details.first + %> + + <%= @title %> at <%= @method %> <%= @path %> - <%= headline %> + + + + + +
+ <%- if project_url -%> + + <%- else %> + + <%- end %> +
+
+ <%= @title %> + at <%= @method %> <%= @path %> +
+

<%= HTML.escape(headline).gsub("'", '\'').gsub(""", '"') %>

+
+ + See raw message + +
<%- details.each do |detail| -%><%= HTML.escape(detail).gsub("'", '\'').gsub(""", '"') %>
+<%- end -%>
+
+
+
+ <% if !@frames.empty? %> +
+
+ <% @frames.each do |frame| %> +
+ + + <%- if !frame.snippets.empty? -%> +
+                        <%- frame.snippets.each do |snippet| -%>
+                          <%= snippet.line %><%= HTML.escape(snippet.code.rstrip).gsub("'", '\'').gsub(""", '"') %>
+                        <%- end -%>
+                      
+ <%- else -%> +
No code available.
+ <%- end -%> + + <% if !frame.args.blank? %> +
+ + <%= frame.label %> + <%= frame.filename %> + +
<%= HTML.escape(frame.args).gsub("'", '\'').gsub(""", '"') %>
+
+ <% else %> +
+
+ <%= frame.label %> + <%= frame.filename %> +
+
+ <% end %> +
+ <% end %> +
+ +
+
+ <%= stack_trace_heading_html %> + +
+ +
    + <% @frames.each do |frame| %> +
  • + +
  • + <% end %> +
+
+
+ <% end %> +
+ +
+ <% if @params && !@params.empty? %> +
+ Params + <% @params.each do |key, value| %> +
+
<%= key %>
+
<%= value.inspect %>
+
+ <% end %> +
+ <% end %> + +
+ Request info + +
+
URI:
+
<%= @url %>
+
+ +
+
Query string:
+
<%= @query %>
+
+
+ +
+ Headers + <% @headers.each do |key, value| %> +
+
<%= key %>
+
<%= value %>
+
+ <% end %> +
+ + <% if (session = @session) && !session.empty? %> +
+ Session + <% session.each do |key, value| %> +
+
<%= key %>
+
<%= value.inspect %>
+
+ <% end %> +
+ <% end %> +
+ + + + diff --git a/lib/exception_page/src/exception_page/frame.cr b/lib/exception_page/src/exception_page/frame.cr new file mode 100644 index 00000000..693fbcda --- /dev/null +++ b/lib/exception_page/src/exception_page/frame.cr @@ -0,0 +1,77 @@ +# :nodoc: +struct ExceptionPage::Frame + property index : Int32, raw_frame : Regex::MatchData + + def initialize(@raw_frame, @index) + end + + def snippets : Array(Snippet) + snippets = [] of Snippet + if File.exists?(file) + lines = File.read_lines(file) + lines.each_with_index do |code, code_index| + if line_is_nearby?(code_index) + highlight = (code_index + 1 == line) ? true : false + snippets << Snippet.new( + line: code_index + 1, + code: code, + highlight: highlight + ) + end + end + end + snippets + end + + private def line_is_nearby?(code_index : Int32) + (code_index + 1) <= (line + 5) && (code_index + 1) >= (line - 5) + end + + def file : String + raw_frame[1] + end + + def filename : String + file.split('/').last + end + + def line : Int32 + raw_frame[2].to_i + end + + def args + "#{file}:#{line}#{column_with_surrounding_method_name}" + end + + private def column_with_surrounding_method_name + raw_frame[3] + end + + def label : String + case file + when .includes?("/crystal/"), .includes?("/crystal-lang/") + "crystal" + when /lib\/(?[^\/]+)\/.+/ + $~["name"] + else + "app" + end + end + + def context : String + if label == "app" + "app" + else + "all" + end + end + + struct Snippet + property line : Int32, + code : String, + highlight : Bool + + def initialize(@line, @code, @highlight) + end + end +end diff --git a/lib/exception_page/src/exception_page/frame_generator.cr b/lib/exception_page/src/exception_page/frame_generator.cr new file mode 100644 index 00000000..c731b59d --- /dev/null +++ b/lib/exception_page/src/exception_page/frame_generator.cr @@ -0,0 +1,13 @@ +# :nodoc: +class ExceptionPage::FrameGenerator + def self.generate_frames(message) + generated_frames = [] of Frame + if raw_frames = message.scan(/\s([^\s\:]+):(\d+)([^\n]+)/) + raw_frames.each_with_index do |frame, index| + generated_frames << Frame.new(raw_frame: frame, index: index) + end + end + + generated_frames + end +end diff --git a/lib/exception_page/src/exception_page/styles.cr b/lib/exception_page/src/exception_page/styles.cr new file mode 100644 index 00000000..e2554f87 --- /dev/null +++ b/lib/exception_page/src/exception_page/styles.cr @@ -0,0 +1,18 @@ +class ExceptionPage::Styles + getter accent : String, + highlight : String, + flash_highlight : String, + logo_uri : String? + + def initialize( + @accent, + @highlight = "#e5e5e5", + @flash_highlight = "#ffdc93", + @logo_uri = crystal_logo + ) + end + + private def crystal_logo + "" + end +end diff --git a/lib/exception_page/src/exception_page/version.cr b/lib/exception_page/src/exception_page/version.cr new file mode 100644 index 00000000..21b4f344 --- /dev/null +++ b/lib/exception_page/src/exception_page/version.cr @@ -0,0 +1,3 @@ +class ExceptionPage + VERSION = "0.1.2" +end diff --git a/lib/kemal/.ameba.yml b/lib/kemal/.ameba.yml new file mode 100644 index 00000000..cf5dbb27 --- /dev/null +++ b/lib/kemal/.ameba.yml @@ -0,0 +1,42 @@ +# This configuration file was generated by `ameba --gen-config` +# on 2019-06-14 15:05:57 UTC using Ameba version 0.10.0. +# The point is for the user to remove these configuration records +# one by one as the reported problems are removed from the code base. + +# Problems found: 7 +# Run `ameba --only Lint/UselessAssign` for details +Lint/UselessAssign: + Description: Disallows useless variable assignments + Enabled: true + Severity: Warning + Excluded: + - spec/view_spec.cr + +# Problems found: 1 +# Run `ameba --only Lint/ShadowingOuterLocalVar` for details +Lint/ShadowingOuterLocalVar: + Description: Disallows the usage of the same name as outer local variables for block + or proc arguments. + Enabled: true + Severity: Warning + Excluded: + - spec/run_spec.cr + +# Problems found: 1 +# Run `ameba --only Style/NegatedConditionsInUnless` for details +Style/NegatedConditionsInUnless: + Description: Disallows negated conditions in unless + Enabled: true + Severity: Convention + Excluded: + - src/kemal/ext/response.cr + +# Problems found: 1 +# Run `ameba --only Metrics/CyclomaticComplexity` for details +Metrics/CyclomaticComplexity: + Description: Disallows methods with a cyclomatic complexity higher than `MaxComplexity` + MaxComplexity: 10 + Enabled: true + Severity: Convention + Excluded: + - src/kemal/static_file_handler.cr diff --git a/lib/kemal/.github/FUNDING.yml b/lib/kemal/.github/FUNDING.yml new file mode 100644 index 00000000..3452e4af --- /dev/null +++ b/lib/kemal/.github/FUNDING.yml @@ -0,0 +1,8 @@ +# These are supported funding model platforms + +github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: sdogruyol +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +custom: # Replace with a single custom sponsorship URL diff --git a/lib/kemal/.github/ISSUE_TEMPLATE.md b/lib/kemal/.github/ISSUE_TEMPLATE.md new file mode 100644 index 00000000..42a81d92 --- /dev/null +++ b/lib/kemal/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,23 @@ +### Description + +[Description of the issue] + +### Steps to Reproduce + +1. [First Step] +2. [Second Step] +3. [and so on...] + +**Expected behavior:** [What you expect to happen] + +**Actual behavior:** [What actually happens] + +**Reproduces how often:** [What percentage of the time does it reproduce?] + +### Versions + +You can get this information from copy and pasting the output of `crystal --version`.Also, please include the OS and what version of the OS you're running. + +### Additional Information + +Any additional information, configuration or data that might be necessary to reproduce the issue. diff --git a/lib/kemal/.github/PULL_REQUEST_TEMPLATE.md b/lib/kemal/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..c97ebc98 --- /dev/null +++ b/lib/kemal/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,15 @@ +### Description of the Change + + + +### Alternate Designs + + + +### Benefits + + + +### Possible Drawbacks + + diff --git a/lib/kemal/.gitignore b/lib/kemal/.gitignore new file mode 100644 index 00000000..6c364e44 --- /dev/null +++ b/lib/kemal/.gitignore @@ -0,0 +1,8 @@ +/lib/ +/.crystal/ +/.shards/ +*.log +/bin/ +# Libraries don't need dependency lock +# Dependencies will be locked in application that uses them +/shard.lock \ No newline at end of file diff --git a/lib/kemal/.travis.yml b/lib/kemal/.travis.yml new file mode 100644 index 00000000..9883f5e9 --- /dev/null +++ b/lib/kemal/.travis.yml @@ -0,0 +1,14 @@ +language: crystal +crystal: + - latest + - nightly + +script: + - crystal spec + - crystal spec --release --no-debug + - crystal tool format --check + - bin/ameba src + +matrix: + allow_failures: + - crystal: nightly diff --git a/lib/kemal/CHANGELOG.md b/lib/kemal/CHANGELOG.md new file mode 100644 index 00000000..24955695 --- /dev/null +++ b/lib/kemal/CHANGELOG.md @@ -0,0 +1,370 @@ +# 0.26.0 (05-08-2019) + +- Crystal 0.30.0 support :tada: [#548](https://github.com/kemalcr/kemal/pull/548) and [#544](https://github.com/kemalcr/kemal/pull/544). Thanks @bcardiff and @straight-shoota :pray: +- Add support for serving files greater than 2^31 bytes [#546](https://github.com/kemalcr/kemal/pull/546). Thanks @omarroth :pray: +- Properly measure request time using `Time.monotonic` [#527](https://github.com/kemalcr/kemal/pull/527). Thanks @spinscale :pray: + +# 0.25.2 (08-02-2019) + +- Add option to config to parse or not command line parameters [#483](https://github.com/kemalcr/kemal/pull/483). Thanks @diegogub :pray: + +- Allow to set filename for `send_file` [#512](https://github.com/kemalcr/kemal/pull/512). Thanks @mamantoha :pray: + + +```ruby +send_file env, "./asset/image.jpeg", filename: "image.jpg" +``` + +- Set `status_code` before response [#513](https://github.com/kemalcr/kemal/pull/513). Thanks @mamantohoa :pray: + +- Use Crystal MIME registry. [#516](https://github.com/kemalcr/kemal/pull/516) Thanks @Sija :pray: + +# 0.25.1 (06-10-2018) + +- Fix `params.files` memoization https://github.com/kemalcr/kemal/pull/503. Thanks @mamantoha :pray: + +# 0.25.0 (05-10-2018) + +- Crystal 0.27.0 support. +- *[breaking change]* Added back `env.params.files`. + +Here's a fully working sample for reading a image file upload `image1` and saving it under `public/uploads`. + +```crystal +post "/upload" do |env| + file = env.params.files["image1"].tempfile + file_path = ::File.join [Kemal.config.public_folder, "uploads/", File.basename(file.path)] + File.open(file_path, "w") do |f| + IO.copy(file, f) + end + "Upload ok" +end +``` + +To test + +`curl -F "image1=@/Users/serdar/Downloads/kemal.png" http://localhost:3000/upload` + +- Cache HTTP routes to increase performance :rocket: https://github.com/kemalcr/kemal/pull/493 + +# 0.24.0 (14-08-2018) + +- *[breaking change]* Removed `env.params.files`. You can use Crystal's built-in `HTTP::FormData.parse` instead + +```ruby +post "/upload" do |env| + HTTP::FormData.parse(env.request) do |upload| + filename = file.filename + + if !filename.is_a?(String) + "No filename included in upload" + else + file_path = ::File.join [Kemal.config.public_folder, "uploads/", filename] + File.open(file_path, "w") do |f| + IO.copy(file.tmpfile, f) + end + "Upload OK" + end +end +``` + +- *[breaking change]* From now on to access dynamic url params in a WebSocket route you have to use: + +```ruby +ws "/:id" do |socket, context| + id = context.ws_route_lookup.params["id"] +end +``` + +- *[breaking change]* Removed `_method` magic param. + +- Added new exception page [#466](https://github.com/kemalcr/kemal/pull/466). Thanks @mamantoha 🙏 + +- Support custom port binding. Thanks @straight-shoota 🙏 + +```ruby +Kemal.run do |config| + server = config.server.not_nil! + server.bind_tcp "127.0.0.1", 3000, reuse_port: true + server.bind_tcp "0.0.0.0", 3001, reuse_port: true +end +``` + +# 0.23.0 (17-06-2018) + +- Crystal 0.25.0 support 🎉 +- Add `Kemal::Context.get?` to safely access context storage :sunglasses: +- [Security] Don't serve 404 image dynamically :thumbsup: +- Disable `X-Powered-By` header [#449](https://github.com/kemalcr/kemal/pull/449). Thanks @Blacksmoke16 🙏 + +# 0.22.0 (29-12-2017) + +- Crystal 0.24.1 support 🎉 +- Only return string from route.[#408](https://github.com/kemalcr/kemal/pull/408) thanks @crisward 🙏 +- Don't crash on empty path when compiled in --release. [#407](https://github.com/kemalcr/kemal/pull/407) thanks @crisward 🙏 +- Rename `Kemal::CommonLogHandler` to `Kemal::LogHandler` and `Kemal::CommonExceptionHandler` to `Kemal::ExceptionHandler`. +- Allow videos to be opened with correct mime type. [#406](https://github.com/kemalcr/kemal/pull/406) thanks @crisward 🙏 +- Add webm mime type.[#413](https://github.com/kemalcr/kemal/pull/413) thanks @reindeer-cafe 🙏 + + +# 0.21.0 (05-09-2017) + +- Dynamically insert handlers :muscle: Fixes [#376](https://github.com/kemalcr/kemal/pull/376). +- Add context to WebSocket. This allows one to use `HTTP::Server::Context` in `ws` declarations :heart_eyes: Fixes [#349](https://github.com/kemalcr/kemal/pull/349). + +```ruby +ws "/:room_name" do |socket, env| + env.params.url["room_name"] +end +``` + +- Add support for customizing the headers of built-in `Kemal::StaticFileHandler` :hammer: Useful for supporting `CORS` for single page applications :clap: + +```ruby +static_headers do |response, filepath, filestat| + if filepath =~ /\.html$/ + response.headers.add("Access-Control-Allow-Origin", "*") + end + response.headers.add("Content-Size", filestat.size.to_s) + end +end +``` + +- Allow %w in Handler macros [#385](https://github.com/kemalcr/kemal/pull/385). Thanks @will :pray: + +- Security: X-Content-Type-Options: nosniff for static files. Fixes [#379](https://github.com/kemalcr/kemal/issues/379). Thanks @crisward :pray: + +- Performance: [Remove tempfile management to OS](https://github.com/kemalcr/kemal/commit/a1520de7ed3865fa73258343a80fad4f20666a99). This brings %10 - 15 performance boost to Kemal :rocket: + +# 0.20.0 (01-07-2017) + +- Crystal 0.23.0 support! As always, Kemal is compatible with the latest major release of Crystal 💎 +- Great news everyone 🎉 All handlers are now completely ***customizable***!. Use the default `Kemal` handlers or go wild, it's all up to you ⛏ + +```ruby +# Don't forget to add `Kemal::RouteHandler::INSTANCE` or your routes won't work! +Kemal.config.handlers = [Kemal::InitHandler.new, YourHandler.new, Kemal::RouteHandler::INSTANCE] +``` + +You can also insert a handler into a specific position. + +```ruby +# This adds MyCustomHandler instance to 1 position. Be aware that the index starts from 0. +add_handler MyCustomHandler.new, 1 +``` +- Updated [Kilt](https://github.com/jeromegn/kilt) to v0.4.0. +- Make `Route` a `Struct`. This improves the performance of route declarations. + +# 0.19.0 (09-05-2017) + +- Return no body for head route fixes #323. (thanks @crisward) +- Update `radix` to `0.3.8`. (thanks @waghanza) +- User defined context store types. (thanks @neovitange) + +```ruby +class User + property name +end + +add_context_storage_type(User) +``` + +- Prevent `send_file returning filesize. (thanks @crisward) +- Dont call setup in `config#add_filter_handler` fixes #338. + +# 0.18.3 (07-03-2017) + +- Remove `Gzip::Header` monkey patch since it's fixed in `Crystal 0.21.1`. + +# 0.18.2 (24-02-2017) + +- Fix [Gzip in Kemal Seems broken for static files](https://github.com/kemalcr/kemal/issues/316). This was caused by `Gzip::Writer` in `Crystal 0.21.0` and currently mitigated by monkey patching `Gzip::Header`. + +# 0.18.1 (21-02-2017) + +- Crystal 0.21.0 support +- Drop `multipart.cr` dependency. `multipart` support is now built-into Crystal <3 +- Since Crystal 0.21.0 comes built-in with `multipart` there are some improvements and deprecations. + +`meta` has been removed from `FileUpload` and it has the following properties + + + `tmpfile`: This is temporary file for file upload. Useful for saving the upload file. + + `filename`: File name of the file upload. (logo.png, images.zip e.g) + + `headers`: Headers for the file upload. + + `creation_time`: Creation time of the file upload. + + `modification_time`: Last Modification time of the file upload. + + `read_time`: Read time of the file upload. + + `size`: Size of the file upload. + + +# 0.18.0 (11-02-2017) + +- Simpler file upload. File uploads can now be access from `HTTP::Server::Context` like `env.params.files["filename"]`. + +`env.params.files["filename"]` has 5 methods + +- `tmpfile`: This is temporary file for file upload. Useful for saving the upload file. +- `tmpfile_path`: File path of `tmpfile`. +- `filename`: File name of the file upload. (logo.png, images.zip e.g) +- `meta`: Meta information for the file upload. +- `headers`: Headers for the file upload. + +Here's a fully working sample for reading a image file upload `image1` and saving it under `public/uploads`. + + ```crystal +post "/upload" do |env| + file = env.params.files["image1"].tmpfile + file_path = ::File.join [Kemal.config.public_folder, "uploads/", file.filename] + File.open(file_path, "w") do |f| + IO.copy(file, f) + end + "Upload ok" +end + ``` + +To test + +`curl -F "image1=@/Users/serdar/Downloads/kemal.png" http://localhost:3000/upload` + +- RF7233 support a.k.a file streaming. (https://github.com/kemalcr/kemal/pull/299) (thanks @denysvitali) + +- Update Radix to 0.3.7. Fixes https://github.com/kemalcr/kemal/issues/293 +- Configurable startup / shutdown logging. https://github.com/kemalcr/kemal/issues/291 and https://github.com/kemalcr/kemal/issues/292 (thanks @twisterghost). + +# 0.17.5 (09-01-2017) + +- Update multipart.cr to 0.1.2. Fixes #285 related to multipart.cr + +# 0.17.4 (24-12-2016) + +- Support for Crystal 0.20.3 +- Add `Kemal.stop`. Fixes #269. +- `HTTP::Handler` is not a class anymore, it's a module. See https://github.com/crystal-lang/crystal/releases/tag/0.20.3 + +# 0.17.3 (03-12-2016) + +- Handle missing 404 image. Fixes #263 +- Remove basic auth middleware from core and move to [kemalcr/kemal-basic-auth](https://github.com/kemalcr/kemal-basic-auth). + +# 0.17.2 (25-11-2016) + +- Use body.gets_to_end for parse_json. Fixes #260. +- Update Radix to 0.3.5 and lock pessimistically. (thanks @luislavena) + +# 0.17.1 (24-11-2016) + +- Treat `HTTP::Request` body as an `IO`. Fixes [#257](https://github.com/sdogruyol/kemal/issues/257) + +# 0.17.0 (23-11-2016) + +- Reimplemented Request middleware / filter routing. + +Now all requests will first go through the Middleware stack then Filters (before_*) and will finally reach the matching route. + +Which is illustrated as, + +``` +Request -> Middleware -> Filter -> Route +``` + +- Rename `return_with` as `halt`. +- Route declaration must start with `/`. Fixes [#242](https://github.com/sdogruyol/kemal/issues/242) +- Set default exception Content-Type to text/html. Fixes [#202](https://github.com/sdogruyol/kemal/issues/242) +- Add `only` and `exclude` paths for `Kemal::Handler`. This change requires that all handlers must inherit from `Kemal::Handler`. + +For example this handler will only work on `/` path. By default the HTTP method is `GET`. + + +```crystal +class OnlyHandler < Kemal::Handler + only ["/"] + + def call(env) + return call_next(env) unless only_match?(env) + puts "If the path is / i will be doing some processing here." + end +end +``` + +The handlers using `exclude` will work on the paths that isn't specified. For example this handler will work on any routes other than `/`. + +```crystal +class ExcludeHandler < Kemal::Handler + exclude ["/"] + + def call(env) + return call_next(env) unless only_match?(env) + puts "If the path is NOT / i will be doing some processing here." + end +end +``` + +- Close response on `halt`. (thanks @samueleaton). +- Update `Radix` to `v0.3.4`. +- `error` handler now also yields error. For example you can get the error mesasage like + +```crystal + error 500 do |env, err| + err.message + end +``` + +- Update `multipart.cr` to `v0.1.1` + +# 0.16.1 (12-10-2016) + +- Improved Multipart support with more info on parsed files. `parse_multipart(env)` now yields +an `UploadFile` object which has the following properties `field`,`data`,`meta`,`headers. + +```crystal +post "/upload" do |env| + parse_multipart(env) do |f| + image1 = f.data if f.field == "image1" + image2 = f.data if f.field == "image2" + puts f.meta + puts f.headers + "Upload complete" + end +end +``` + +# 0.16.0 + +- Multipart support <3 (thanks @RX14). Now you can handle file uploads. + +```crystal +post "/upload" do |env| + parse_multipart(env) do |field, data| + image1 = data if field == "image1" + image2 = data if field == "image2" + "Upload complete" + end +end +``` + +- Make session configurable. Now you can specify session name and expire time wit + +```crystal +Kemal.config.session["name"] = "your_app" +Kemal.config.session["expire_time"] = 48.hours +``` + +- Session now supports more types. (String, Int32, Float64, Bool) +- Add `gzip` helper to enable / disable gzip compression on responses. +- Static file caching with etag and gzip (thanks @crisward) +- `Kemal.run` now accepts port to listen. + +# 0.15.1 (05-09-2016) + +- Don't forget to call_next on NullLogHandler + +# 0.15.0 (03-09-2016) + +- Add context store +- `KEMAL_ENV` respects to `Kemal.config.env` and needs to be explicitly set. +- `Kemal::InitHandler` is introduced. Adds initial configuration, headers like `X-Powered-By`. +- Add `send_file` to helpers. +- Add mime types. +- Fix parsing JSON params when "charset" is present in "Content-Type" header. +- Use http-only cookie for session +- Inject STDOUT by default in CommonLogHandler diff --git a/lib/kemal/LICENSE b/lib/kemal/LICENSE new file mode 100644 index 00000000..4b1bea44 --- /dev/null +++ b/lib/kemal/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2016 Serdar Doğruyol + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE diff --git a/lib/kemal/README.md b/lib/kemal/README.md new file mode 100644 index 00000000..068cc25a --- /dev/null +++ b/lib/kemal/README.md @@ -0,0 +1,67 @@ + +[![Kemal](https://avatars3.githubusercontent.com/u/15321198?v=3&s=200)](http://kemalcr.com) + +# Kemal + +Lightning Fast, Super Simple web framework. + +[![Build Status](https://travis-ci.org/kemalcr/kemal.svg?branch=master)](https://travis-ci.org/kemalcr/kemal) +[![Join the chat at https://gitter.im/sdogruyol/kemal](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sdogruyol/kemal?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +# Super Simple ⚡️ + +```ruby +require "kemal" + +# Matches GET "http://host:port/" +get "/" do + "Hello World!" +end + +# Creates a WebSocket handler. +# Matches "ws://host:port/socket" +ws "/socket" do |socket| + socket.send "Hello from Kemal!" +end + +Kemal.run +``` + +Start your application! + +``` +crystal src/kemal_sample.cr +``` +Go to *http://localhost:3000* + +Check [documentation](http://kemalcr.com) or [samples](https://github.com/kemalcr/kemal/tree/master/samples) for more. + +# Installation + +Add this to your application's `shard.yml`: + +```yaml +dependencies: + kemal: + github: kemalcr/kemal +``` + +See also [Getting Started](http://kemalcr.com/guide/). + +# Features + +- Support all REST verbs +- Websocket support +- Request/Response context, easy parameter handling +- Middleware support +- Built-in JSON support +- Built-in static file serving +- Built-in view templating via [Kilt](https://github.com/jeromegn/kilt) + +# Documentation + +You can read the documentation at the official site [kemalcr.com](http://kemalcr.com) + +## Thanks + +Thanks to Manas for their awesome work on [Frank](https://github.com/manastech/frank). diff --git a/lib/kemal/samples/hello_world.cr b/lib/kemal/samples/hello_world.cr new file mode 100644 index 00000000..c04f1d58 --- /dev/null +++ b/lib/kemal/samples/hello_world.cr @@ -0,0 +1,8 @@ +require "kemal" + +# Set root. If not specified the default content_type is 'text' +get "/" do + "Hello Kemal!" +end + +Kemal.run diff --git a/lib/kemal/samples/json_api.cr b/lib/kemal/samples/json_api.cr new file mode 100644 index 00000000..0132c149 --- /dev/null +++ b/lib/kemal/samples/json_api.cr @@ -0,0 +1,11 @@ +require "kemal" +require "json" + +# You can easily access the context and set content_type like 'application/json'. +# Look how easy to build a JSON serving API. +get "/" do |env| + env.response.content_type = "application/json" + {name: "Serdar", age: 27}.to_json +end + +Kemal.run diff --git a/lib/kemal/samples/websocket_server.cr b/lib/kemal/samples/websocket_server.cr new file mode 100644 index 00000000..61a08029 --- /dev/null +++ b/lib/kemal/samples/websocket_server.cr @@ -0,0 +1,11 @@ +require "kemal" + +ws "/" do |socket| + socket.send "Hello from Kemal!" + + socket.on_message do |message| + socket.send "Echo back from server #{message}" + end +end + +Kemal.run diff --git a/lib/kemal/shard.yml b/lib/kemal/shard.yml new file mode 100644 index 00000000..963884f0 --- /dev/null +++ b/lib/kemal/shard.yml @@ -0,0 +1,25 @@ +name: kemal +version: 0.26.0 + +authors: + - Serdar Dogruyol + +dependencies: + radix: + github: luislavena/radix + version: ~> 0.3.8 + kilt: + github: jeromegn/kilt + version: ~> 0.4.0 + exception_page: + github: crystal-loot/exception_page + version: ~> 0.1.1 + +development_dependencies: + ameba: + github: veelenga/ameba + version: ~> 0.10.0 + +crystal: 0.30.0 + +license: MIT diff --git a/lib/kemal/spec/all_spec.cr b/lib/kemal/spec/all_spec.cr new file mode 100644 index 00000000..938dadf6 --- /dev/null +++ b/lib/kemal/spec/all_spec.cr @@ -0,0 +1 @@ +require "./*" diff --git a/lib/kemal/spec/asset/hello.ecr b/lib/kemal/spec/asset/hello.ecr new file mode 100644 index 00000000..1cc8d414 --- /dev/null +++ b/lib/kemal/spec/asset/hello.ecr @@ -0,0 +1 @@ +Hello <%= name %> diff --git a/lib/kemal/spec/asset/hello_with_content_for.ecr b/lib/kemal/spec/asset/hello_with_content_for.ecr new file mode 100644 index 00000000..149b2944 --- /dev/null +++ b/lib/kemal/spec/asset/hello_with_content_for.ecr @@ -0,0 +1,5 @@ +Hello <%= name %> + +<% content_for "custom" do %> +

Hello from otherside

+<% end %> \ No newline at end of file diff --git a/lib/kemal/spec/asset/layout.ecr b/lib/kemal/spec/asset/layout.ecr new file mode 100644 index 00000000..d493b59d --- /dev/null +++ b/lib/kemal/spec/asset/layout.ecr @@ -0,0 +1 @@ +<%= content %> diff --git a/lib/kemal/spec/asset/layout_with_yield.ecr b/lib/kemal/spec/asset/layout_with_yield.ecr new file mode 100644 index 00000000..f6cd6736 --- /dev/null +++ b/lib/kemal/spec/asset/layout_with_yield.ecr @@ -0,0 +1,6 @@ + + + <%= content %> + <%= yield_content "custom" %> + + \ No newline at end of file diff --git a/lib/kemal/spec/asset/layout_with_yield_and_vars.ecr b/lib/kemal/spec/asset/layout_with_yield_and_vars.ecr new file mode 100644 index 00000000..3a82a7a6 --- /dev/null +++ b/lib/kemal/spec/asset/layout_with_yield_and_vars.ecr @@ -0,0 +1,8 @@ + + + <%= content %> + <%= yield_content "custom" %> + <%= var1 %> + <%= var2 %> + + \ No newline at end of file diff --git a/lib/kemal/spec/config_spec.cr b/lib/kemal/spec/config_spec.cr new file mode 100644 index 00000000..31a84380 --- /dev/null +++ b/lib/kemal/spec/config_spec.cr @@ -0,0 +1,61 @@ +require "./spec_helper" + +describe "Config" do + it "sets default port to 3000" do + Kemal::Config.new.port.should eq 3000 + end + + it "sets default environment to development" do + Kemal::Config.new.env.should eq "development" + end + + it "sets environment to production" do + config = Kemal.config + config.env = "production" + config.env.should eq "production" + end + + it "sets default powered_by_header to true" do + Kemal::Config.new.powered_by_header.should be_true + end + + it "sets host binding" do + config = Kemal.config + config.host_binding = "127.0.0.1" + config.host_binding.should eq "127.0.0.1" + end + + it "adds a custom handler" do + config = Kemal.config + config.add_handler CustomTestHandler.new + Kemal.config.setup + config.handlers.size.should eq(7) + end + + it "toggles the shutdown message" do + config = Kemal.config + config.shutdown_message = false + config.shutdown_message.should eq false + config.shutdown_message = true + config.shutdown_message.should eq true + end + + it "adds custom options" do + config = Kemal.config + ARGV.push("--test") + ARGV.push("FOOBAR") + test_option = nil + + config.extra_options do |parser| + parser.on("--test TEST_OPTION", "Test an option") do |opt| + test_option = opt + end + end + Kemal::CLI.new ARGV + test_option.should eq("FOOBAR") + end + + it "gets the version from shards.yml" do + Kemal::VERSION.should_not be("") + end +end diff --git a/lib/kemal/spec/context_spec.cr b/lib/kemal/spec/context_spec.cr new file mode 100644 index 00000000..c9729266 --- /dev/null +++ b/lib/kemal/spec/context_spec.cr @@ -0,0 +1,107 @@ +require "./spec_helper" + +describe "Context" do + context "headers" do + it "sets content type" do + get "/" do |env| + env.response.content_type = "application/json" + "Hello" + end + request = HTTP::Request.new("GET", "/") + client_response = call_request_on_app(request) + client_response.headers["Content-Type"].should eq("application/json") + end + + it "parses headers" do + get "/" do |env| + name = env.request.headers["name"] + "Hello #{name}" + end + headers = HTTP::Headers.new + headers["name"] = "kemal" + request = HTTP::Request.new("GET", "/", headers) + client_response = call_request_on_app(request) + client_response.body.should eq "Hello kemal" + end + + it "sets response headers" do + get "/" do |env| + env.response.headers.add "Accept-Language", "tr" + end + request = HTTP::Request.new("GET", "/") + client_response = call_request_on_app(request) + client_response.headers["Accept-Language"].should eq "tr" + end + end + + context "storage" do + it "can store primitive types" do + before_get "/" do |env| + env.set "before_get", "Kemal" + env.set "before_get_int", 123 + env.set "before_get_float", 3.5 + end + + get "/" do |env| + { + before_get: env.get("before_get"), + before_get_int: env.get("before_get_int"), + before_get_float: env.get("before_get_float"), + } + end + + request = HTTP::Request.new("GET", "/") + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + Kemal::FilterHandler::INSTANCE.call(context) + Kemal::RouteHandler::INSTANCE.call(context) + + context.get("before_get").should eq "Kemal" + context.get("before_get_int").should eq 123 + context.get("before_get_float").should eq 3.5 + end + + it "can store custom types" do + before_get "/" do |env| + t = TestContextStorageType.new + t.id = 32 + a = AnotherContextStorageType.new + + env.set "before_get_context_test", t + env.set "another_context_test", a + end + + get "/" do |env| + { + before_get_context_test: env.get("before_get_context_test"), + another_context_test: env.get("another_context_test"), + } + end + + request = HTTP::Request.new("GET", "/") + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + Kemal::FilterHandler::INSTANCE.call(context) + Kemal::RouteHandler::INSTANCE.call(context) + + context.get("before_get_context_test").as(TestContextStorageType).id.should eq 32 + context.get("another_context_test").as(AnotherContextStorageType).name.should eq "kemal-context" + end + + it "fetches non-existent keys from store with get?" do + get "/" { } + + request = HTTP::Request.new("GET", "/") + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + Kemal::FilterHandler::INSTANCE.call(context) + Kemal::RouteHandler::INSTANCE.call(context) + + context.get?("non_existent_key").should be_nil + context.get?("another_non_existent_key").should be_nil + end + end +end diff --git a/lib/kemal/spec/exception_handler_spec.cr b/lib/kemal/spec/exception_handler_spec.cr new file mode 100644 index 00000000..b9519e9d --- /dev/null +++ b/lib/kemal/spec/exception_handler_spec.cr @@ -0,0 +1,115 @@ +require "./spec_helper" + +describe "Kemal::ExceptionHandler" do + it "renders 404 on route not found" do + get "/" do + "Hello" + end + + request = HTTP::Request.new("GET", "/asd") + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + Kemal::ExceptionHandler::INSTANCE.call(context) + response.close + io.rewind + response = HTTP::Client::Response.from_io(io, decompress: false) + response.status_code.should eq 404 + end + + it "renders custom error" do + error 403 do + "403 error" + end + get "/" do |env| + env.response.status_code = 403 + end + request = HTTP::Request.new("GET", "/") + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + Kemal::ExceptionHandler::INSTANCE.next = Kemal::RouteHandler::INSTANCE + Kemal::ExceptionHandler::INSTANCE.call(context) + response.close + io.rewind + response = HTTP::Client::Response.from_io(io, decompress: false) + response.status_code.should eq 403 + response.headers["Content-Type"].should eq "text/html" + response.body.should eq "403 error" + end + + it "renders custom 500 error" do + error 500 do + "Something happened" + end + get "/" do |env| + env.response.status_code = 500 + end + request = HTTP::Request.new("GET", "/") + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + Kemal::ExceptionHandler::INSTANCE.next = Kemal::RouteHandler::INSTANCE + Kemal::ExceptionHandler::INSTANCE.call(context) + response.close + io.rewind + response = HTTP::Client::Response.from_io(io, decompress: false) + response.status_code.should eq 500 + response.headers["Content-Type"].should eq "text/html" + response.body.should eq "Something happened" + end + + it "keeps the specified error Content-Type" do + error 500 do + "Something happened" + end + get "/" do |env| + env.response.content_type = "application/json" + env.response.status_code = 500 + end + request = HTTP::Request.new("GET", "/") + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + Kemal::ExceptionHandler::INSTANCE.next = Kemal::RouteHandler::INSTANCE + Kemal::ExceptionHandler::INSTANCE.call(context) + response.close + io.rewind + response = HTTP::Client::Response.from_io(io, decompress: false) + response.status_code.should eq 500 + response.headers["Content-Type"].should eq "application/json" + response.body.should eq "Something happened" + end + + it "renders custom error with env and error" do + error 500 do |_, err| + err.message + end + get "/" do |env| + env.response.content_type = "application/json" + env.response.status_code = 500 + end + request = HTTP::Request.new("GET", "/") + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + Kemal::ExceptionHandler::INSTANCE.next = Kemal::RouteHandler::INSTANCE + Kemal::ExceptionHandler::INSTANCE.call(context) + response.close + io.rewind + response = HTTP::Client::Response.from_io(io, decompress: false) + response.status_code.should eq 500 + response.headers["Content-Type"].should eq "application/json" + response.body.should eq "Rendered error with 500" + end + + it "does not do anything on a closed io" do + get "/" do |env| + halt env, status_code: 404 + end + + request = HTTP::Request.new("GET", "/") + client_response = call_request_on_app(request) + client_response.status_code.should eq 404 + end +end diff --git a/lib/kemal/spec/handler_spec.cr b/lib/kemal/spec/handler_spec.cr new file mode 100644 index 00000000..9b1019fc --- /dev/null +++ b/lib/kemal/spec/handler_spec.cr @@ -0,0 +1,161 @@ +require "./spec_helper" + +class CustomTestHandler < Kemal::Handler + def call(env) + env.response << "Kemal" + call_next env + end +end + +class OnlyHandler < Kemal::Handler + only ["/only"] + + def call(env) + return call_next(env) unless only_match?(env) + env.response.print "Only" + call_next env + end +end + +class ExcludeHandler < Kemal::Handler + exclude ["/exclude"] + + def call(env) + return call_next(env) if exclude_match?(env) + env.response.print "Exclude" + call_next env + end +end + +class PostOnlyHandler < Kemal::Handler + only ["/only", "/route1", "/route2"], "POST" + + def call(env) + return call_next(env) unless only_match?(env) + env.response.print "Only" + call_next env + end +end + +class PostExcludeHandler < Kemal::Handler + exclude ["/exclude"], "POST" + + def call(env) + return call_next(env) if exclude_match?(env) + env.response.print "Exclude" + call_next env + end +end + +class ExcludeHandlerPercentW < Kemal::Handler + exclude %w[/exclude] + + def call(env) + return call_next(env) if exclude_match?(env) + env.response.print "Exclude" + call_next env + end +end + +class PostOnlyHandlerPercentW < Kemal::Handler + only %w[/only /route1 /route2], "POST" + + def call(env) + return call_next(env) unless only_match?(env) + env.response.print "Only" + call_next env + end +end + +describe "Handler" do + it "adds custom handler before before_*" do + filter_middleware = Kemal::FilterHandler.new + filter_middleware._add_route_filter("GET", "/", :before) do |env| + env.response << " is" + end + + filter_middleware._add_route_filter("GET", "/", :before) do |env| + env.response << " so" + end + add_handler CustomTestHandler.new + + get "/" do + " Great" + end + request = HTTP::Request.new("GET", "/") + client_response = call_request_on_app(request) + client_response.status_code.should eq(200) + client_response.body.should eq("Kemal is so Great") + end + + it "runs specified only_routes in middleware" do + get "/only" do + "Get" + end + add_handler OnlyHandler.new + request = HTTP::Request.new("GET", "/only") + client_response = call_request_on_app(request) + client_response.body.should eq "OnlyGet" + end + + it "doesn't run specified exclude_routes in middleware" do + get "/" do + "Get" + end + get "/exclude" do + "Exclude" + end + add_handler ExcludeHandler.new + request = HTTP::Request.new("GET", "/") + client_response = call_request_on_app(request) + client_response.body.should eq "ExcludeGet" + end + + it "runs specified only_routes with method in middleware" do + post "/only" do + "Post" + end + get "/only" do + "Get" + end + add_handler PostOnlyHandler.new + request = HTTP::Request.new("POST", "/only") + client_response = call_request_on_app(request) + client_response.body.should eq "OnlyPost" + end + + it "doesn't run specified exclude_routes with method in middleware" do + post "/exclude" do + "Post" + end + post "/only" do + "Post" + end + add_handler PostOnlyHandler.new + add_handler PostExcludeHandler.new + request = HTTP::Request.new("POST", "/only") + client_response = call_request_on_app(request) + client_response.body.should eq "OnlyExcludePost" + end + + it "adds a handler at given position" do + post_handler = PostOnlyHandler.new + add_handler post_handler, 1 + Kemal.config.setup + Kemal.config.handlers[1].should eq post_handler + end + + it "assigns custom handlers" do + post_only_handler = PostOnlyHandler.new + post_exclude_handler = PostExcludeHandler.new + Kemal.config.handlers = [post_only_handler, post_exclude_handler] + Kemal.config.handlers.should eq [post_only_handler, post_exclude_handler] + end + + it "is able to use %w in macros" do + post_only_handler = PostOnlyHandlerPercentW.new + exclude_handler = ExcludeHandlerPercentW.new + Kemal.config.handlers = [post_only_handler, exclude_handler] + Kemal.config.handlers.should eq [post_only_handler, exclude_handler] + end +end diff --git a/lib/kemal/spec/helpers_spec.cr b/lib/kemal/spec/helpers_spec.cr new file mode 100644 index 00000000..77737bd6 --- /dev/null +++ b/lib/kemal/spec/helpers_spec.cr @@ -0,0 +1,155 @@ +require "./spec_helper" +require "./handler_spec" + +describe "Macros" do + describe "#public_folder" do + it "sets public folder" do + public_folder "/some/path/to/folder" + Kemal.config.public_folder.should eq("/some/path/to/folder") + end + end + + describe "#add_handler" do + it "adds a custom handler" do + add_handler CustomTestHandler.new + Kemal.config.setup + Kemal.config.handlers.size.should eq 7 + end + end + + describe "#logging" do + it "sets logging status" do + logging false + Kemal.config.logging.should eq false + end + + it "sets a custom logger" do + config = Kemal::Config::INSTANCE + logger CustomLogHandler.new + config.logger.should be_a(CustomLogHandler) + end + end + + describe "#halt" do + it "can break block with halt macro" do + get "/non-breaking" do + "hello" + "world" + end + request = HTTP::Request.new("GET", "/non-breaking") + client_response = call_request_on_app(request) + client_response.status_code.should eq(200) + client_response.body.should eq("world") + + get "/breaking" do |env| + halt env, 404, "hello" + "world" + end + request = HTTP::Request.new("GET", "/breaking") + client_response = call_request_on_app(request) + client_response.status_code.should eq(404) + client_response.body.should eq("hello") + end + + it "can break block with halt macro using default values" do + get "/" do |env| + halt env + "world" + end + request = HTTP::Request.new("GET", "/") + client_response = call_request_on_app(request) + client_response.status_code.should eq(200) + client_response.body.should eq("") + end + end + + describe "#headers" do + it "can add headers" do + get "/headers" do |env| + env.response.headers.add "Content-Type", "image/png" + headers env, { + "Access-Control-Allow-Origin" => "*", + "Content-Type" => "text/plain", + } + end + request = HTTP::Request.new("GET", "/headers") + response = call_request_on_app(request) + response.headers["Access-Control-Allow-Origin"].should eq("*") + response.headers["Content-Type"].should eq("text/plain") + end + end + + describe "#send_file" do + it "sends file with given path and default mime-type" do + get "/" do |env| + send_file env, "./spec/asset/hello.ecr" + end + + request = HTTP::Request.new("GET", "/") + response = call_request_on_app(request) + response.status_code.should eq(200) + response.headers["Content-Type"].should eq("application/octet-stream") + response.headers["Content-Length"].should eq("18") + end + + it "sends file with given path and given mime-type" do + get "/" do |env| + send_file env, "./spec/asset/hello.ecr", "image/jpeg" + end + + request = HTTP::Request.new("GET", "/") + response = call_request_on_app(request) + response.status_code.should eq(200) + response.headers["Content-Type"].should eq("image/jpeg") + response.headers["Content-Length"].should eq("18") + end + + it "sends file with binary stream" do + get "/" do |env| + send_file env, "Serdar".to_slice + end + + request = HTTP::Request.new("GET", "/") + response = call_request_on_app(request) + response.status_code.should eq(200) + response.headers["Content-Type"].should eq("application/octet-stream") + response.headers["Content-Length"].should eq("6") + end + + it "sends file with given path and given filename" do + get "/" do |env| + send_file env, "./spec/asset/hello.ecr", filename: "image.jpg" + end + + request = HTTP::Request.new("GET", "/") + response = call_request_on_app(request) + response.status_code.should eq(200) + response.headers["Content-Disposition"].should eq("attachment; filename=\"image.jpg\"") + end + end + + describe "#gzip" do + it "adds HTTP::CompressHandler to handlers" do + gzip true + Kemal.config.setup + Kemal.config.handlers[4].should be_a(HTTP::CompressHandler) + end + end + + describe "#serve_static" do + it "should disable static file hosting" do + serve_static false + Kemal.config.serve_static.should eq false + end + + it "should disble enable gzip and dir_listing" do + serve_static({"gzip" => true, "dir_listing" => true}) + conf = Kemal.config.serve_static + conf.is_a?(Hash).should eq true + if conf.is_a?(Hash) + conf["gzip"].should eq true + conf["dir_listing"].should eq true + end + end + end +end diff --git a/lib/kemal/spec/init_handler_spec.cr b/lib/kemal/spec/init_handler_spec.cr new file mode 100644 index 00000000..601bbc1d --- /dev/null +++ b/lib/kemal/spec/init_handler_spec.cr @@ -0,0 +1,32 @@ +require "./spec_helper" + +describe "Kemal::InitHandler" do + it "initializes context with Content-Type: text/html" do + request = HTTP::Request.new("GET", "/") + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + Kemal::InitHandler::INSTANCE.next = ->(_context : HTTP::Server::Context) {} + Kemal::InitHandler::INSTANCE.call(context) + context.response.headers["Content-Type"].should eq "text/html" + end + + it "initializes context with X-Powered-By: Kemal" do + request = HTTP::Request.new("GET", "/") + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + Kemal::InitHandler::INSTANCE.call(context) + context.response.headers["X-Powered-By"].should eq "Kemal" + end + + it "does not initialize context with X-Powered-By: Kemal if disabled" do + Kemal.config.powered_by_header = false + request = HTTP::Request.new("GET", "/") + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + Kemal::InitHandler::INSTANCE.call(context) + context.response.headers["X-Powered-By"]?.should be_nil + end +end diff --git a/lib/kemal/spec/log_handler_spec.cr b/lib/kemal/spec/log_handler_spec.cr new file mode 100644 index 00000000..5ee9c863 --- /dev/null +++ b/lib/kemal/spec/log_handler_spec.cr @@ -0,0 +1,21 @@ +require "./spec_helper" + +describe "Kemal::LogHandler" do + it "logs to the given IO" do + io = IO::Memory.new + logger = Kemal::LogHandler.new io + logger.write "Something" + io.to_s.should eq "Something" + end + + it "creates log message for each request" do + request = HTTP::Request.new("GET", "/") + io = IO::Memory.new + context_io = IO::Memory.new + response = HTTP::Server::Response.new(context_io) + context = HTTP::Server::Context.new(request, response) + logger = Kemal::LogHandler.new io + logger.call(context) + io.to_s.should_not be nil + end +end diff --git a/lib/kemal/spec/middleware/filters_spec.cr b/lib/kemal/spec/middleware/filters_spec.cr new file mode 100644 index 00000000..9bc2564c --- /dev/null +++ b/lib/kemal/spec/middleware/filters_spec.cr @@ -0,0 +1,190 @@ +require "../spec_helper" + +describe "Kemal::FilterHandler" do + it "executes code before home request" do + test_filter = FilterTest.new + test_filter.modified = "false" + + filter_middleware = Kemal::FilterHandler.new + filter_middleware._add_route_filter("GET", "/greetings", :before) { test_filter.modified = "true" } + + kemal = Kemal::RouteHandler::INSTANCE + kemal.add_route "GET", "/greetings" { test_filter.modified } + + test_filter.modified.should eq("false") + request = HTTP::Request.new("GET", "/greetings") + create_request_and_return_io_and_context(filter_middleware, request) + io_with_context = create_request_and_return_io_and_context(kemal, request)[0] + client_response = HTTP::Client::Response.from_io(io_with_context, decompress: false) + client_response.body.should eq("true") + end + + it "executes code before GET home request but not POST home request" do + test_filter = FilterTest.new + test_filter.modified = "false" + + filter_middleware = Kemal::FilterHandler.new + filter_middleware._add_route_filter("GET", "/greetings", :before) { test_filter.modified = test_filter.modified == "true" ? "false" : "true" } + + kemal = Kemal::RouteHandler::INSTANCE + kemal.add_route "GET", "/greetings" { test_filter.modified } + kemal.add_route "POST", "/greetings" { test_filter.modified } + + test_filter.modified.should eq("false") + + request = HTTP::Request.new("GET", "/greetings") + create_request_and_return_io_and_context(filter_middleware, request) + io_with_context = create_request_and_return_io_and_context(kemal, request)[0] + client_response = HTTP::Client::Response.from_io(io_with_context, decompress: false) + client_response.body.should eq("true") + + request = HTTP::Request.new("POST", "/greetings") + create_request_and_return_io_and_context(filter_middleware, request) + io_with_context = create_request_and_return_io_and_context(kemal, request)[0] + client_response = HTTP::Client::Response.from_io(io_with_context, decompress: false) + client_response.body.should eq("true") + end + + it "executes code before all GET/POST home request" do + test_filter = FilterTest.new + test_filter.modified = "false" + + filter_middleware = Kemal::FilterHandler.new + filter_middleware._add_route_filter("ALL", "/greetings", :before) { test_filter.modified = test_filter.modified == "true" ? "false" : "true" } + filter_middleware._add_route_filter("GET", "/greetings", :before) { test_filter.modified = test_filter.modified == "true" ? "false" : "true" } + filter_middleware._add_route_filter("POST", "/greetings", :before) { test_filter.modified = test_filter.modified == "true" ? "false" : "true" } + + kemal = Kemal::RouteHandler::INSTANCE + kemal.add_route "GET", "/greetings" { test_filter.modified } + kemal.add_route "POST", "/greetings" { test_filter.modified } + + test_filter.modified.should eq("false") + + request = HTTP::Request.new("GET", "/greetings") + create_request_and_return_io_and_context(filter_middleware, request) + io_with_context = create_request_and_return_io_and_context(kemal, request)[0] + client_response = HTTP::Client::Response.from_io(io_with_context, decompress: false) + client_response.body.should eq("false") + + request = HTTP::Request.new("POST", "/greetings") + create_request_and_return_io_and_context(filter_middleware, request) + io_with_context = create_request_and_return_io_and_context(kemal, request)[0] + client_response = HTTP::Client::Response.from_io(io_with_context, decompress: false) + client_response.body.should eq("false") + end + + it "executes code after home request" do + test_filter = FilterTest.new + test_filter.modified = "false" + + filter_middleware = Kemal::FilterHandler.new + filter_middleware._add_route_filter("GET", "/greetings", :after) { test_filter.modified = "true" } + + kemal = Kemal::RouteHandler::INSTANCE + kemal.add_route "GET", "/greetings" { test_filter.modified } + + test_filter.modified.should eq("false") + request = HTTP::Request.new("GET", "/greetings") + create_request_and_return_io_and_context(filter_middleware, request) + io_with_context = create_request_and_return_io_and_context(kemal, request)[0] + client_response = HTTP::Client::Response.from_io(io_with_context, decompress: false) + client_response.body.should eq("true") + end + + it "executes code after GET home request but not POST home request" do + test_filter = FilterTest.new + test_filter.modified = "false" + + filter_middleware = Kemal::FilterHandler.new + filter_middleware._add_route_filter("GET", "/greetings", :after) { test_filter.modified = test_filter.modified == "true" ? "false" : "true" } + + kemal = Kemal::RouteHandler::INSTANCE + kemal.add_route "GET", "/greetings" { test_filter.modified } + kemal.add_route "POST", "/greetings" { test_filter.modified } + + test_filter.modified.should eq("false") + + request = HTTP::Request.new("GET", "/greetings") + create_request_and_return_io_and_context(filter_middleware, request) + io_with_context = create_request_and_return_io_and_context(kemal, request)[0] + client_response = HTTP::Client::Response.from_io(io_with_context, decompress: false) + client_response.body.should eq("true") + + request = HTTP::Request.new("POST", "/greetings") + create_request_and_return_io_and_context(filter_middleware, request) + io_with_context = create_request_and_return_io_and_context(kemal, request)[0] + client_response = HTTP::Client::Response.from_io(io_with_context, decompress: false) + client_response.body.should eq("true") + end + + it "executes code after all GET/POST home request" do + test_filter = FilterTest.new + test_filter.modified = "false" + + filter_middleware = Kemal::FilterHandler.new + filter_middleware._add_route_filter("ALL", "/greetings", :after) { test_filter.modified = test_filter.modified == "true" ? "false" : "true" } + filter_middleware._add_route_filter("GET", "/greetings", :after) { test_filter.modified = test_filter.modified == "true" ? "false" : "true" } + filter_middleware._add_route_filter("POST", "/greetings", :after) { test_filter.modified = test_filter.modified == "true" ? "false" : "true" } + + kemal = Kemal::RouteHandler::INSTANCE + kemal.add_route "GET", "/greetings" { test_filter.modified } + kemal.add_route "POST", "/greetings" { test_filter.modified } + + test_filter.modified.should eq("false") + request = HTTP::Request.new("GET", "/greetings") + create_request_and_return_io_and_context(filter_middleware, request) + io_with_context = create_request_and_return_io_and_context(kemal, request)[0] + client_response = HTTP::Client::Response.from_io(io_with_context, decompress: false) + client_response.body.should eq("false") + + request = HTTP::Request.new("POST", "/greetings") + create_request_and_return_io_and_context(filter_middleware, request) + io_with_context = create_request_and_return_io_and_context(kemal, request)[0] + client_response = HTTP::Client::Response.from_io(io_with_context, decompress: false) + client_response.body.should eq("false") + end + + it "executes 3 differents blocks after all request" do + test_filter = FilterTest.new + test_filter.modified = "false" + test_filter_second = FilterTest.new + test_filter_second.modified = "false" + test_filter_third = FilterTest.new + test_filter_third.modified = "false" + + filter_middleware = Kemal::FilterHandler.new + filter_middleware._add_route_filter("ALL", "/greetings", :before) { test_filter.modified = test_filter.modified == "true" ? "false" : "true" } + filter_middleware._add_route_filter("ALL", "/greetings", :before) { test_filter_second.modified = test_filter_second.modified == "true" ? "false" : "true" } + filter_middleware._add_route_filter("ALL", "/greetings", :before) { test_filter_third.modified = test_filter_third.modified == "true" ? "false" : "true" } + + kemal = Kemal::RouteHandler::INSTANCE + kemal.add_route "GET", "/greetings" { test_filter.modified } + kemal.add_route "POST", "/greetings" { test_filter_second.modified } + kemal.add_route "PUT", "/greetings" { test_filter_third.modified } + + test_filter.modified.should eq("false") + test_filter_second.modified.should eq("false") + test_filter_third.modified.should eq("false") + request = HTTP::Request.new("GET", "/greetings") + create_request_and_return_io_and_context(filter_middleware, request) + io_with_context = create_request_and_return_io_and_context(kemal, request)[0] + client_response = HTTP::Client::Response.from_io(io_with_context, decompress: false) + client_response.body.should eq("true") + + request = HTTP::Request.new("POST", "/greetings") + create_request_and_return_io_and_context(filter_middleware, request) + io_with_context = create_request_and_return_io_and_context(kemal, request)[0] + client_response = HTTP::Client::Response.from_io(io_with_context, decompress: false) + client_response.body.should eq("false") + + request = HTTP::Request.new("PUT", "/greetings") + create_request_and_return_io_and_context(filter_middleware, request) + io_with_context = create_request_and_return_io_and_context(kemal, request)[0] + client_response = HTTP::Client::Response.from_io(io_with_context, decompress: false) + client_response.body.should eq("true") + end +end + +class FilterTest + property modified : String? +end diff --git a/lib/kemal/spec/param_parser_spec.cr b/lib/kemal/spec/param_parser_spec.cr new file mode 100644 index 00000000..d63a2298 --- /dev/null +++ b/lib/kemal/spec/param_parser_spec.cr @@ -0,0 +1,204 @@ +require "./spec_helper" + +describe "ParamParser" do + it "parses query params" do + Route.new "POST", "/" do |env| + hasan = env.params.query["hasan"] + "Hello #{hasan}" + end + request = HTTP::Request.new("POST", "/?hasan=cemal") + query_params = Kemal::ParamParser.new(request).query + query_params["hasan"].should eq "cemal" + end + + it "parses multiple values for query params" do + Route.new "POST", "/" do |env| + hasan = env.params.query["hasan"] + "Hello #{hasan}" + end + request = HTTP::Request.new("POST", "/?hasan=cemal&hasan=lamec") + query_params = Kemal::ParamParser.new(request).query + query_params.fetch_all("hasan").should eq ["cemal", "lamec"] + end + + it "parses url params" do + kemal = Kemal::RouteHandler::INSTANCE + kemal.add_route "POST", "/hello/:hasan" do |env| + "hello #{env.params.url["hasan"]}" + end + request = HTTP::Request.new("POST", "/hello/cemal") + # Radix tree MUST be run to parse url params. + context = create_request_and_return_io_and_context(kemal, request)[1] + url_params = Kemal::ParamParser.new(request, context.route_lookup.params).url + url_params["hasan"].should eq "cemal" + end + + it "decodes url params" do + kemal = Kemal::RouteHandler::INSTANCE + kemal.add_route "POST", "/hello/:email/:money/:spanish" do |env| + email = env.params.url["email"] + money = env.params.url["money"] + spanish = env.params.url["spanish"] + "Hello, #{email}. You have #{money}. The spanish word of the day is #{spanish}." + end + request = HTTP::Request.new("POST", "/hello/sam%2Bspec%40gmail.com/%2419.99/a%C3%B1o") + # Radix tree MUST be run to parse url params. + context = create_request_and_return_io_and_context(kemal, request)[1] + url_params = Kemal::ParamParser.new(request, context.route_lookup.params).url + url_params["email"].should eq "sam+spec@gmail.com" + url_params["money"].should eq "$19.99" + url_params["spanish"].should eq "año" + end + + it "parses request body" do + Route.new "POST", "/" do |env| + name = env.params.query["name"] + age = env.params.query["age"] + hasan = env.params.body["hasan"] + "Hello #{name} #{hasan} #{age}" + end + + request = HTTP::Request.new( + "POST", + "/?hasan=cemal", + body: "name=serdar&age=99", + headers: HTTP::Headers{"Content-Type" => "application/x-www-form-urlencoded"}, + ) + + query_params = Kemal::ParamParser.new(request).query + {"hasan" => "cemal"}.each do |k, v| + query_params[k].should eq(v) + end + + body_params = Kemal::ParamParser.new(request).body + {"name" => "serdar", "age" => "99"}.each do |k, v| + body_params[k].should eq(v) + end + end + + it "parses multiple values in request body" do + Route.new "POST", "/" do |env| + hasan = env.params.body["hasan"] + "Hello #{hasan}" + end + + request = HTTP::Request.new( + "POST", + "/", + body: "hasan=cemal&hasan=lamec", + headers: HTTP::Headers{"Content-Type" => "application/x-www-form-urlencoded"}, + ) + + body_params = Kemal::ParamParser.new(request).body + body_params.fetch_all("hasan").should eq(["cemal", "lamec"]) + end + + context "when content type is application/json" do + it "parses request body" do + Route.new "POST", "/" { } + + request = HTTP::Request.new( + "POST", + "/", + body: "{\"name\": \"Serdar\"}", + headers: HTTP::Headers{"Content-Type" => "application/json"}, + ) + + json_params = Kemal::ParamParser.new(request).json + json_params.should eq({"name" => "Serdar"}) + end + + it "parses request body when passed charset" do + Route.new "POST", "/" { } + + request = HTTP::Request.new( + "POST", + "/", + body: "{\"name\": \"Serdar\"}", + headers: HTTP::Headers{"Content-Type" => "application/json; charset=utf-8"}, + ) + + json_params = Kemal::ParamParser.new(request).json + json_params.should eq({"name" => "Serdar"}) + end + + it "parses request body for array" do + Route.new "POST", "/" { } + + request = HTTP::Request.new( + "POST", + "/", + body: "[1]", + headers: HTTP::Headers{"Content-Type" => "application/json"}, + ) + + json_params = Kemal::ParamParser.new(request).json + json_params.should eq({"_json" => [1]}) + end + + it "parses request body and query params" do + Route.new "POST", "/" { } + + request = HTTP::Request.new( + "POST", + "/?foo=bar", + body: "[1]", + headers: HTTP::Headers{"Content-Type" => "application/json"}, + ) + + query_params = Kemal::ParamParser.new(request).query + {"foo" => "bar"}.each do |k, v| + query_params[k].should eq(v) + end + + json_params = Kemal::ParamParser.new(request).json + json_params.should eq({"_json" => [1]}) + end + + it "handles no request body" do + Route.new "GET", "/" { } + + request = HTTP::Request.new( + "GET", + "/", + headers: HTTP::Headers{"Content-Type" => "application/json"}, + ) + + url_params = Kemal::ParamParser.new(request).url + url_params.should eq({} of String => String) + + query_params = Kemal::ParamParser.new(request).query + query_params.to_s.should eq("") + + body_params = Kemal::ParamParser.new(request).body + body_params.to_s.should eq("") + + json_params = Kemal::ParamParser.new(request).json + json_params.should eq({} of String => Nil | String | Int64 | Float64 | Bool | Hash(String, JSON::Any) | Array(JSON::Any)) + end + end + + context "when content type is incorrect" do + it "does not parse request body" do + Route.new "POST", "/" do |env| + name = env.params.body["name"] + age = env.params.body["age"] + hasan = env.params.query["hasan"] + "Hello #{name} #{hasan} #{age}" + end + + request = HTTP::Request.new( + "POST", + "/?hasan=cemal", + body: "name=serdar&age=99", + headers: HTTP::Headers{"Content-Type" => "text/plain"}, + ) + + query_params = Kemal::ParamParser.new(request).query + query_params["hasan"].should eq("cemal") + + body_params = Kemal::ParamParser.new(request).body + body_params.to_s.should eq("") + end + end +end diff --git a/lib/kemal/spec/route_handler_spec.cr b/lib/kemal/spec/route_handler_spec.cr new file mode 100644 index 00000000..a6db5e12 --- /dev/null +++ b/lib/kemal/spec/route_handler_spec.cr @@ -0,0 +1,123 @@ +require "./spec_helper" + +describe "Kemal::RouteHandler" do + it "routes" do + get "/" do + "hello" + end + request = HTTP::Request.new("GET", "/") + client_response = call_request_on_app(request) + client_response.body.should eq("hello") + end + + it "routes should only return strings" do + get "/" do + 100 + end + request = HTTP::Request.new("GET", "/") + client_response = call_request_on_app(request) + client_response.body.should eq("") + end + + it "routes request with query string" do + get "/" do |env| + "hello #{env.params.query["message"]}" + end + request = HTTP::Request.new("GET", "/?message=world") + client_response = call_request_on_app(request) + client_response.body.should eq("hello world") + end + + it "routes request with multiple query strings" do + get "/" do |env| + "hello #{env.params.query["message"]} time #{env.params.query["time"]}" + end + request = HTTP::Request.new("GET", "/?message=world&time=now") + client_response = call_request_on_app(request) + client_response.body.should eq("hello world time now") + end + + it "route parameter has more precedence than query string arguments" do + get "/:message" do |env| + "hello #{env.params.url["message"]}" + end + request = HTTP::Request.new("GET", "/world?message=coco") + client_response = call_request_on_app(request) + client_response.body.should eq("hello world") + end + + it "parses simple JSON body" do + post "/" do |env| + name = env.params.json["name"] + age = env.params.json["age"] + "Hello #{name} Age #{age}" + end + + json_payload = {"name": "Serdar", "age": 26} + request = HTTP::Request.new( + "POST", + "/", + body: json_payload.to_json, + headers: HTTP::Headers{"Content-Type" => "application/json"}, + ) + client_response = call_request_on_app(request) + client_response.body.should eq("Hello Serdar Age 26") + end + + it "parses JSON with string array" do + post "/" do |env| + skills = env.params.json["skills"].as(Array) + "Skills #{skills.each.join(',')}" + end + + json_payload = {"skills": ["ruby", "crystal"]} + request = HTTP::Request.new( + "POST", + "/", + body: json_payload.to_json, + headers: HTTP::Headers{"Content-Type" => "application/json"}, + ) + client_response = call_request_on_app(request) + client_response.body.should eq("Skills ruby,crystal") + end + + it "parses JSON with json object array" do + post "/" do |env| + skills = env.params.json["skills"].as(Array) + skills_from_languages = skills.map do |skill| + skill["language"] + end + "Skills #{skills_from_languages.each.join(',')}" + end + + json_payload = {"skills": [{"language": "ruby"}, {"language": "crystal"}]} + request = HTTP::Request.new( + "POST", + "/", + body: json_payload.to_json, + headers: HTTP::Headers{"Content-Type" => "application/json"}, + ) + + client_response = call_request_on_app(request) + client_response.body.should eq("Skills ruby,crystal") + end + + it "can process HTTP HEAD requests for defined GET routes" do + get "/" do + "Hello World from GET" + end + request = HTTP::Request.new("HEAD", "/") + client_response = call_request_on_app(request) + client_response.status_code.should eq(200) + end + + it "redirects user to provided url" do + get "/" do |env| + env.redirect "/login" + end + request = HTTP::Request.new("GET", "/") + client_response = call_request_on_app(request) + client_response.status_code.should eq(302) + client_response.headers.has_key?("Location").should eq(true) + end +end diff --git a/lib/kemal/spec/route_spec.cr b/lib/kemal/spec/route_spec.cr new file mode 100644 index 00000000..7634d51d --- /dev/null +++ b/lib/kemal/spec/route_spec.cr @@ -0,0 +1,25 @@ +require "./spec_helper" + +describe "Route" do + describe "match?" do + it "matches the correct route" do + get "/route1" do + "Route 1" + end + get "/route2" do + "Route 2" + end + request = HTTP::Request.new("GET", "/route2") + client_response = call_request_on_app(request) + client_response.body.should eq("Route 2") + end + + it "doesn't allow a route declaration start without /" do + expect_raises Kemal::Exceptions::InvalidPathStartException, "Route declaration get \"route\" needs to start with '/', should be get \"/route\"" do + get "route" do + "Route 1" + end + end + end + end +end diff --git a/lib/kemal/spec/run_spec.cr b/lib/kemal/spec/run_spec.cr new file mode 100644 index 00000000..25534552 --- /dev/null +++ b/lib/kemal/spec/run_spec.cr @@ -0,0 +1,48 @@ +require "./spec_helper" + +private def run(code) + code = <<-CR + require "./src/kemal" + #{code} + CR + String.build do |stdout| + stderr = String.build do |stderr| + Process.new("crystal", ["eval"], input: IO::Memory.new(code), output: stdout, error: stderr).wait + end + unless stderr.empty? + fail(stderr) + end + end +end + +describe "Run" do + it "runs a code block after starting" do + run(<<-CR).should eq "started\nstopped\n" + Kemal.config.env = "test" + Kemal.run do + puts "started" + Kemal.stop + puts "stopped" + end + CR + end + + it "runs without a block being specified" do + run(<<-CR).should eq "[test] Kemal is ready to lead at http://0.0.0.0:3000\ntrue\n" + Kemal.config.env = "test" + Kemal.run + puts Kemal.config.running + CR + end + + it "allows custom HTTP::Server bind" do + run(<<-CR).should eq "[test] Kemal is ready to lead at http://127.0.0.1:3000, http://0.0.0.0:3001\n" + Kemal.config.env = "test" + Kemal.run do |config| + server = config.server.not_nil! + server.bind_tcp "127.0.0.1", 3000, reuse_port: true + server.bind_tcp "0.0.0.0", 3001, reuse_port: true + end + CR + end +end diff --git a/lib/kemal/spec/spec_helper.cr b/lib/kemal/spec/spec_helper.cr new file mode 100644 index 00000000..0bc127ad --- /dev/null +++ b/lib/kemal/spec/spec_helper.cr @@ -0,0 +1,88 @@ +require "spec" +require "../src/*" + +include Kemal + +class CustomLogHandler < Kemal::BaseLogHandler + def call(env) + call_next env + end + + def write(message) + end +end + +class TestContextStorageType + property id + @id = 1 + + def to_s + @id + end +end + +class AnotherContextStorageType + property name + @name = "kemal-context" +end + +add_context_storage_type(TestContextStorageType) +add_context_storage_type(AnotherContextStorageType) + +def create_request_and_return_io_and_context(handler, request) + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + handler.call(context) + response.close + io.rewind + {io, context} +end + +def create_ws_request_and_return_io_and_context(handler, request) + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + begin + handler.call context + rescue IO::Error + # Raises because the IO::Memory is empty + end + io.rewind + {io, context} +end + +def call_request_on_app(request) + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + main_handler = build_main_handler + main_handler.call context + response.close + io.rewind + HTTP::Client::Response.from_io(io, decompress: false) +end + +def build_main_handler + Kemal.config.setup + main_handler = Kemal.config.handlers.first + current_handler = main_handler + Kemal.config.handlers.each do |handler| + current_handler.next = handler + current_handler = handler + end + main_handler +end + +Spec.before_each do + config = Kemal.config + config.env = "development" + config.logging = false +end + +Spec.after_each do + Kemal.config.clear + Kemal::RouteHandler::INSTANCE.routes = Radix::Tree(Route).new + Kemal::RouteHandler::INSTANCE.cached_routes = Hash(String, Radix::Result(Route)).new + Kemal::WebSocketHandler::INSTANCE.routes = Radix::Tree(WebSocket).new +end diff --git a/lib/kemal/spec/static/dir/bigger.txt b/lib/kemal/spec/static/dir/bigger.txt new file mode 100644 index 00000000..36281ae8 --- /dev/null +++ b/lib/kemal/spec/static/dir/bigger.txt @@ -0,0 +1,9 @@ +Lorem ipsum dolor sit amet, consectetur adipiscing elit. Suspendisse posuere cursus consectetur. Donec mauris lorem, sodales a eros a, ultricies convallis ante. Quisque elementum lacus purus, sagittis mollis justo dignissim ac. Suspendisse potenti. Cras non mauris accumsan mi porttitor congue. Quisque posuere aliquam tellus sit amet ultrices. Sed at tortor sed libero fringilla luctus vitae quis magna. In maximus congue felis, et porta tortor egestas sed. Phasellus orci eros, finibus sed ipsum eget, euismod bibendum nisl. Etiam ultrices facilisis diam in gravida. Praesent lobortis leo vitae aliquet volutpat. Praesent vel blandit risus. In suscipit eget nunc at ultrices. Proin dapibus feugiat diam ut tincidunt. Donec lectus diam, ornare ut consequat nec, gravida sit amet metus. + +Nunc a viverra urna, quis ullamcorper augue. Morbi posuere auctor nibh, tempor luctus massa mollis laoreet. Pellentesque sagittis leo eu felis interdum finibus. Pellentesque porttitor lobortis arcu, eu mollis dui iaculis nec. Vestibulum sit amet sodales erat. Nullam quis mi massa. Suspendisse sit amet elit auctor, feugiat ipsum a, placerat metus. Vestibulum quis felis a lectus blandit aliquam. Nam consectetur iaculis nulla. Mauris sit amet condimentum erat, in vestibulum dui. Nullam nec mattis tortor, non viverra nunc. Proin eget congue augue. Cum sociis natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. Sed ut hendrerit nulla. Etiam cursus sagittis metus, et feugiat ligula molestie sit amet. Aliquam laoreet auctor sagittis. + +Aliquam tempor urna non consectetur tincidunt. Maecenas porttitor augue diam, ac lobortis nulla suscipit eget. Ut quis lacus facilisis, euismod lacus non, ullamcorper urna. Cras pretium fringilla pharetra. Praesent sed nunc at elit vulputate elementum. Suspendisse ac molestie nunc, sit amet consectetur nunc. Cras placerat ligula tortor, non bibendum massa tempus ut. Etiam eros erat, gravida id felis eget, congue suscipit ipsum. Sed condimentum erat at facilisis dictum. Cras venenatis vitae turpis vitae sagittis. Proin id posuere est, non ornare sem. Donec vitae sollicitudin dolor, a pulvinar ex. Integer porta velit lectus, et imperdiet enim commodo a. + +Donec sit amet ipsum tempus, tincidunt neque eget, luctus massa. Praesent vel nulla pretium, bibendum enim a, pulvinar enim. Vestibulum non libero eu est dignissim cursus. Nullam commodo tellus imperdiet feugiat placerat. Sed sed dolor ut nibh blandit maximus ac eget neque. Ut sit amet augue maximus, lacinia eros non, faucibus eros. Suspendisse ac bibendum libero, eu lobortis nulla. Mauris arcu nulla, tempus eu varius eu, bibendum at nibh. Donec id libero consequat, volutpat ex vitae, molestie velit. Aliquam aliquam sem ac arcu pellentesque, placerat bibendum enim dapibus. Duis consectetur ligula non placerat euismod. + +Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Proin commodo ullamcorper venenatis. Cras ac lorem sit amet augue varius convallis. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Mauris dolor nisi, efficitur id aliquet ut, ultricies sed elit. Proin ultricies turpis dolor, in auctor velit aliquet nec. Praesent vehicula aliquam viverra. Suspendisse potenti. Donec aliquet iaculis ultricies. Proin dignissim vitae nisl at rutrum. \ No newline at end of file diff --git a/lib/kemal/spec/static/dir/index.html b/lib/kemal/spec/static/dir/index.html new file mode 100644 index 00000000..32d977f9 --- /dev/null +++ b/lib/kemal/spec/static/dir/index.html @@ -0,0 +1,12 @@ + + + + + title + + + + + + + \ No newline at end of file diff --git a/lib/kemal/spec/static/dir/test.txt b/lib/kemal/spec/static/dir/test.txt new file mode 100644 index 00000000..9db7df02 --- /dev/null +++ b/lib/kemal/spec/static/dir/test.txt @@ -0,0 +1,2 @@ +hello +world \ No newline at end of file diff --git a/lib/kemal/spec/static_file_handler_spec.cr b/lib/kemal/spec/static_file_handler_spec.cr new file mode 100644 index 00000000..1aac161b --- /dev/null +++ b/lib/kemal/spec/static_file_handler_spec.cr @@ -0,0 +1,153 @@ +require "./spec_helper" + +private def handle(request, fallthrough = true) + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + handler = Kemal::StaticFileHandler.new "#{__DIR__}/static", fallthrough + handler.call context + response.close + io.rewind + HTTP::Client::Response.from_io(io) +end + +describe Kemal::StaticFileHandler do + file = File.open "#{__DIR__}/static/dir/test.txt" + file_size = file.size + + it "should serve a file with content type and etag" do + response = handle HTTP::Request.new("GET", "/dir/test.txt") + response.status_code.should eq(200) + response.headers["Content-Type"].should eq "text/plain" + response.headers["Etag"].should contain "W/\"" + response.body.should eq(File.read("#{__DIR__}/static/dir/test.txt")) + end + + it "should respond with 304 if file has not changed" do + response = handle HTTP::Request.new("GET", "/dir/test.txt") + response.status_code.should eq(200) + etag = response.headers["Etag"] + + headers = HTTP::Headers{"If-None-Match" => etag} + response = handle HTTP::Request.new("GET", "/dir/test.txt", headers) + response.headers["Content-Type"]?.should be_nil + response.status_code.should eq(304) + response.body.should eq "" + end + + it "should not list directory's entries" do + serve_static({"gzip" => true, "dir_listing" => false}) + response = handle HTTP::Request.new("GET", "/dir/") + response.status_code.should eq(404) + end + + it "should list directory's entries when config is set" do + serve_static({"gzip" => true, "dir_listing" => true}) + response = handle HTTP::Request.new("GET", "/dir/") + response.status_code.should eq(200) + response.body.should match(/test.txt/) + end + + it "should gzip a file if config is true, headers accept gzip and file is > 880 bytes" do + serve_static({"gzip" => true, "dir_listing" => true}) + headers = HTTP::Headers{"Accept-Encoding" => "gzip, deflate, sdch, br"} + response = handle HTTP::Request.new("GET", "/dir/bigger.txt", headers) + response.status_code.should eq(200) + response.headers["Content-Encoding"].should eq "gzip" + end + + it "should not gzip a file if config is true, headers accept gzip and file is < 880 bytes" do + serve_static({"gzip" => true, "dir_listing" => true}) + headers = HTTP::Headers{"Accept-Encoding" => "gzip, deflate, sdch, br"} + response = handle HTTP::Request.new("GET", "/dir/test.txt", headers) + response.status_code.should eq(200) + response.headers["Content-Encoding"]?.should be_nil + end + + it "should not gzip a file if config is false, headers accept gzip and file is > 880 bytes" do + serve_static({"gzip" => false, "dir_listing" => true}) + headers = HTTP::Headers{"Accept-Encoding" => "gzip, deflate, sdch, br"} + response = handle HTTP::Request.new("GET", "/dir/bigger.txt", headers) + response.status_code.should eq(200) + response.headers["Content-Encoding"]?.should be_nil + end + + it "should not serve a not found file" do + response = handle HTTP::Request.new("GET", "/not_found_file.txt") + response.status_code.should eq(404) + end + + it "should not serve a not found directory" do + response = handle HTTP::Request.new("GET", "/not_found_dir/") + response.status_code.should eq(404) + end + + it "should not serve a file as directory" do + response = handle HTTP::Request.new("GET", "/dir/test.txt/") + response.status_code.should eq(404) + end + + it "should handle only GET and HEAD method" do + %w(GET HEAD).each do |method| + response = handle HTTP::Request.new(method, "/dir/test.txt") + response.status_code.should eq(200) + end + + %w(POST PUT DELETE).each do |method| + response = handle HTTP::Request.new(method, "/dir/test.txt") + response.status_code.should eq(404) + response = handle HTTP::Request.new(method, "/dir/test.txt"), false + response.status_code.should eq(405) + response.headers["Allow"].should eq("GET, HEAD") + end + end + + it "should send part of files when requested (RFC7233)" do + %w(POST PUT DELETE HEAD).each do |method| + headers = HTTP::Headers{"Range" => "0-100"} + response = handle HTTP::Request.new(method, "/dir/test.txt", headers) + response.status_code.should_not eq(206) + response.headers.has_key?("Content-Range").should eq(false) + end + + %w(GET).each do |method| + headers = HTTP::Headers{"Range" => "0-100"} + response = handle HTTP::Request.new(method, "/dir/test.txt", headers) + response.status_code.should eq(206 || 200) + if response.status_code == 206 + response.headers.has_key?("Content-Range").should eq true + match = response.headers["Content-Range"].match(/bytes (\d+)-(\d+)\/(\d+)/) + match.should_not be_nil + if match + start_range = match[1].to_i { 0 } + end_range = match[2].to_i { 0 } + range_size = match[3].to_i { 0 } + + range_size.should eq file_size + (end_range < file_size).should eq true + (start_range < end_range).should eq true + end + end + end + end + + it "should handle setting custom headers" do + headers = Proc(HTTP::Server::Response, String, File::Info, Void).new do |response, path, stat| + if path =~ /\.html$/ + response.headers.add("Access-Control-Allow-Origin", "*") + end + response.headers.add("Content-Size", stat.size.to_s) + end + + static_headers(&headers) + + response = handle HTTP::Request.new("GET", "/dir/test.txt") + response.headers.has_key?("Access-Control-Allow-Origin").should be_false + response.headers["Content-Size"].should eq( + File.info("#{__DIR__}/static/dir/test.txt").size.to_s + ) + + response = handle HTTP::Request.new("GET", "/dir/index.html") + response.headers["Access-Control-Allow-Origin"].should eq("*") + end +end diff --git a/lib/kemal/spec/view_spec.cr b/lib/kemal/spec/view_spec.cr new file mode 100644 index 00000000..d09f4de4 --- /dev/null +++ b/lib/kemal/spec/view_spec.cr @@ -0,0 +1,62 @@ +require "./spec_helper" + +macro render_with_base_and_layout(filename) + render "spec/asset/#{{{filename}}}", "spec/asset/layout.ecr" +end + +describe "Views" do + it "renders file" do + get "/view/:name" do |env| + name = env.params.url["name"] + render "spec/asset/hello.ecr" + end + request = HTTP::Request.new("GET", "/view/world") + client_response = call_request_on_app(request) + client_response.body.should contain("Hello world") + end + + it "renders file with dynamic variables" do + get "/view/:name" do |env| + name = env.params.url["name"] + render_with_base_and_layout "hello.ecr" + end + request = HTTP::Request.new("GET", "/view/world") + client_response = call_request_on_app(request) + client_response.body.should contain("Hello world") + end + + it "renders layout" do + get "/view/:name" do |env| + name = env.params.url["name"] + render "spec/asset/hello.ecr", "spec/asset/layout.ecr" + end + request = HTTP::Request.new("GET", "/view/world") + client_response = call_request_on_app(request) + client_response.body.should contain("Hello world") + end + + it "renders layout with variables" do + get "/view/:name" do |env| + name = env.params.url["name"] + var1 = "serdar" + var2 = "kemal" + render "spec/asset/hello_with_content_for.ecr", "spec/asset/layout_with_yield_and_vars.ecr" + end + request = HTTP::Request.new("GET", "/view/world") + client_response = call_request_on_app(request) + client_response.body.should contain("Hello world") + client_response.body.should contain("serdar") + client_response.body.should contain("kemal") + end + + it "renders layout with content_for" do + get "/view/:name" do |env| + name = env.params.url["name"] + render "spec/asset/hello_with_content_for.ecr", "spec/asset/layout_with_yield.ecr" + end + request = HTTP::Request.new("GET", "/view/world") + client_response = call_request_on_app(request) + client_response.body.should contain("Hello world") + client_response.body.should contain("

Hello from otherside

") + end +end diff --git a/lib/kemal/spec/websocket_handler_spec.cr b/lib/kemal/spec/websocket_handler_spec.cr new file mode 100644 index 00000000..bc02d3c2 --- /dev/null +++ b/lib/kemal/spec/websocket_handler_spec.cr @@ -0,0 +1,68 @@ +require "./spec_helper" + +describe "Kemal::WebSocketHandler" do + it "doesn't match on wrong route" do + handler = Kemal::WebSocketHandler::INSTANCE + handler.next = Kemal::RouteHandler::INSTANCE + ws "/" { } + headers = HTTP::Headers{ + "Upgrade" => "websocket", + "Connection" => "Upgrade", + "Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==", + } + request = HTTP::Request.new("GET", "/asd", headers) + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + + expect_raises(Kemal::Exceptions::RouteNotFound) do + handler.call context + end + end + + it "matches on given route" do + handler = Kemal::WebSocketHandler::INSTANCE + ws "/" { |socket| socket.send("Match") } + ws "/no_match" { |socket| socket.send "No Match" } + headers = HTTP::Headers{ + "Upgrade" => "websocket", + "Connection" => "Upgrade", + "Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version" => "13", + } + request = HTTP::Request.new("GET", "/", headers) + + io_with_context = create_ws_request_and_return_io_and_context(handler, request)[0] + io_with_context.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n\x81\u0005Match") + end + + it "fetches named url parameters" do + handler = Kemal::WebSocketHandler::INSTANCE + ws "/:id" { |_, c| c.ws_route_lookup.params["id"] } + headers = HTTP::Headers{ + "Upgrade" => "websocket", + "Connection" => "Upgrade", + "Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version" => "13", + } + request = HTTP::Request.new("GET", "/1234", headers) + io_with_context = create_ws_request_and_return_io_and_context(handler, request)[0] + io_with_context.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n") + end + + it "matches correct verb" do + handler = Kemal::WebSocketHandler::INSTANCE + handler.next = Kemal::RouteHandler::INSTANCE + ws "/" { } + get "/" { "get" } + request = HTTP::Request.new("GET", "/") + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + handler.call(context) + response.close + io.rewind + client_response = HTTP::Client::Response.from_io(io, decompress: false) + client_response.body.should eq("get") + end +end diff --git a/lib/kemal/src/kemal.cr b/lib/kemal/src/kemal.cr new file mode 100644 index 00000000..009337e2 --- /dev/null +++ b/lib/kemal/src/kemal.cr @@ -0,0 +1,98 @@ +require "http" +require "json" +require "uri" +require "./kemal/*" +require "./kemal/ext/*" +require "./kemal/helpers/*" + +module Kemal + # Overload of `self.run` with the default startup logging. + def self.run(port : Int32?, args = ARGV) + self.run(port, args) { } + end + + # Overload of `self.run` without port. + def self.run(args = ARGV) + self.run(nil, args: args) + end + + # Overload of `self.run` to allow just a block. + def self.run(args = ARGV, &block) + self.run(nil, args: args, &block) + end + + # The command to run a `Kemal` application. + # + # If *port* is not given Kemal will use `Kemal::Config#port` + # + # To use custom command line arguments, set args to nil + # + def self.run(port : Int32? = nil, args = ARGV, &block) + Kemal::CLI.new args + config = Kemal.config + config.setup + config.port = port if port + + # Test environment doesn't need to have signal trap and logging. + if config.env != "test" + setup_404 + setup_trap_signal + end + + server = config.server ||= HTTP::Server.new(config.handlers) + + config.running = true + + yield config + + # Abort if block called `Kemal.stop` + return unless config.running + + unless server.each_address { |_| break true } + {% if flag?(:without_openssl) %} + server.bind_tcp(config.host_binding, config.port) + {% else %} + if ssl = config.ssl + server.bind_tls(config.host_binding, config.port, ssl) + else + server.bind_tcp(config.host_binding, config.port) + end + {% end %} + end + + display_startup_message(config, server) + + server.listen unless config.env == "test" + end + + def self.display_startup_message(config, server) + addresses = server.addresses.map { |address| "#{config.scheme}://#{address}" }.join ", " + log "[#{config.env}] Kemal is ready to lead at #{addresses}" + end + + def self.stop + raise "Kemal is already stopped." if !config.running + if server = config.server + server.close unless server.closed? + config.running = false + else + raise "Kemal.config.server is not set. Please use Kemal.run to set the server." + end + end + + private def self.setup_404 + unless Kemal.config.error_handlers.has_key?(404) + error 404 do + render_404 + end + end + end + + private def self.setup_trap_signal + Signal::INT.trap do + log "Kemal is going to take a rest!" if Kemal.config.shutdown_message + Kemal.stop + exit + end + end +end diff --git a/lib/kemal/src/kemal/base_log_handler.cr b/lib/kemal/src/kemal/base_log_handler.cr new file mode 100644 index 00000000..37ee980b --- /dev/null +++ b/lib/kemal/src/kemal/base_log_handler.cr @@ -0,0 +1,9 @@ +module Kemal + # All loggers must inherit from `Kemal::BaseLogHandler`. + abstract class BaseLogHandler + include HTTP::Handler + + abstract def call(context : HTTP::Server::Context) + abstract def write(message : String) + end +end diff --git a/lib/kemal/src/kemal/cli.cr b/lib/kemal/src/kemal/cli.cr new file mode 100644 index 00000000..656a4e69 --- /dev/null +++ b/lib/kemal/src/kemal/cli.cr @@ -0,0 +1,62 @@ +require "option_parser" + +module Kemal + # Handles all the initialization from the command line. + class CLI + def initialize(args) + @ssl_enabled = false + @key_file = "" + @cert_file = "" + @config = Kemal.config + read_env + if args + parse args + end + configure_ssl + end + + private def parse(args : Array(String)) + OptionParser.parse args do |opts| + opts.on("-b HOST", "--bind HOST", "Host to bind (defaults to 0.0.0.0)") do |host_binding| + @config.host_binding = host_binding + end + opts.on("-p PORT", "--port PORT", "Port to listen for connections (defaults to 3000)") do |opt_port| + @config.port = opt_port.to_i + end + opts.on("-s", "--ssl", "Enables SSL") do + @ssl_enabled = true + end + opts.on("--ssl-key-file FILE", "SSL key file") do |key_file| + @key_file = key_file + end + opts.on("--ssl-cert-file FILE", "SSL certificate file") do |cert_file| + @cert_file = cert_file + end + opts.on("-h", "--help", "Shows this help") do + puts opts + exit 0 + end + @config.extra_options.try &.call(opts) + end + end + + private def configure_ssl + {% if !flag?(:without_openssl) %} + if @ssl_enabled + abort "SSL Key Not Found" if !@key_file + abort "SSL Certificate Not Found" if !@cert_file + ssl = Kemal::SSL.new + ssl.key_file = @key_file.not_nil! + ssl.cert_file = @cert_file.not_nil! + Kemal.config.ssl = ssl.context + end + {% end %} + end + + private def read_env + if kemal_env = ENV["KEMAL_ENV"]? + @config.env = kemal_env + end + end + end +end diff --git a/lib/kemal/src/kemal/config.cr b/lib/kemal/src/kemal/config.cr new file mode 100644 index 00000000..04bbdd7e --- /dev/null +++ b/lib/kemal/src/kemal/config.cr @@ -0,0 +1,168 @@ +module Kemal + VERSION = {{ `shards version #{__DIR__}`.chomp.stringify }} + + # Stores all the configuration options for a Kemal application. + # It's a singleton and you can access it like. + # + # ``` + # Kemal.config + # ``` + class Config + INSTANCE = Config.new + HANDLERS = [] of HTTP::Handler + CUSTOM_HANDLERS = [] of Tuple(Nil | Int32, HTTP::Handler) + FILTER_HANDLERS = [] of HTTP::Handler + ERROR_HANDLERS = {} of Int32 => HTTP::Server::Context, Exception -> String + + {% if flag?(:without_openssl) %} + @ssl : Bool? + {% else %} + @ssl : OpenSSL::SSL::Context::Server? + {% end %} + + property host_binding, ssl, port, env, public_folder, logging, running + property always_rescue, server : HTTP::Server?, extra_options, shutdown_message + property serve_static : (Bool | Hash(String, Bool)) + property static_headers : (HTTP::Server::Response, String, File::Info -> Void)? + property powered_by_header : Bool = true + + def initialize + @host_binding = "0.0.0.0" + @port = 3000 + @env = "development" + @serve_static = {"dir_listing" => false, "gzip" => true} + @public_folder = "./public" + @logging = true + @logger = nil + @error_handler = nil + @always_rescue = true + @router_included = false + @default_handlers_setup = false + @running = false + @shutdown_message = true + @handler_position = 0 + end + + def logger + @logger.not_nil! + end + + def logger=(logger : Kemal::BaseLogHandler) + @logger = logger + end + + def scheme + ssl ? "https" : "http" + end + + def clear + @powered_by_header = true + @router_included = false + @handler_position = 0 + @default_handlers_setup = false + HANDLERS.clear + CUSTOM_HANDLERS.clear + FILTER_HANDLERS.clear + ERROR_HANDLERS.clear + end + + def handlers + HANDLERS + end + + def handlers=(handlers : Array(HTTP::Handler)) + clear + HANDLERS.replace(handlers) + end + + def add_handler(handler : HTTP::Handler) + CUSTOM_HANDLERS << {nil, handler} + end + + def add_handler(handler : HTTP::Handler, position : Int32) + CUSTOM_HANDLERS << {position, handler} + end + + def add_filter_handler(handler : HTTP::Handler) + FILTER_HANDLERS << handler + end + + def error_handlers + ERROR_HANDLERS + end + + def add_error_handler(status_code : Int32, &handler : HTTP::Server::Context, Exception -> _) + ERROR_HANDLERS[status_code] = ->(context : HTTP::Server::Context, error : Exception) { handler.call(context, error).to_s } + end + + def extra_options(&@extra_options : OptionParser ->) + end + + def setup + unless @default_handlers_setup && @router_included + setup_init_handler + setup_log_handler + setup_error_handler + setup_static_file_handler + setup_custom_handlers + setup_filter_handlers + @default_handlers_setup = true + @router_included = true + HANDLERS.insert(HANDLERS.size, Kemal::WebSocketHandler::INSTANCE) + HANDLERS.insert(HANDLERS.size, Kemal::RouteHandler::INSTANCE) + end + end + + private def setup_init_handler + HANDLERS.insert(@handler_position, Kemal::InitHandler::INSTANCE) + @handler_position += 1 + end + + private def setup_log_handler + @logger ||= if @logging + Kemal::LogHandler.new + else + Kemal::NullLogHandler.new + end + HANDLERS.insert(@handler_position, @logger.not_nil!) + @handler_position += 1 + end + + private def setup_error_handler + if @always_rescue + @error_handler ||= Kemal::ExceptionHandler.new + HANDLERS.insert(@handler_position, @error_handler.not_nil!) + @handler_position += 1 + end + end + + private def setup_static_file_handler + if @serve_static.is_a?(Hash) + HANDLERS.insert(@handler_position, Kemal::StaticFileHandler.new(@public_folder)) + @handler_position += 1 + end + end + + private def setup_custom_handlers + CUSTOM_HANDLERS.each do |ch0, ch1| + position = ch0 + HANDLERS.insert (position || @handler_position), ch1 + @handler_position += 1 + end + end + + private def setup_filter_handlers + FILTER_HANDLERS.each do |h| + HANDLERS.insert(@handler_position, h) + end + end + end + + def self.config + yield Config::INSTANCE + end + + def self.config + Config::INSTANCE + end +end diff --git a/lib/kemal/src/kemal/dsl.cr b/lib/kemal/src/kemal/dsl.cr new file mode 100644 index 00000000..15b37424 --- /dev/null +++ b/lib/kemal/src/kemal/dsl.cr @@ -0,0 +1,37 @@ +# Kemal DSL is defined here and it's baked into global scope. +# +# The DSL currently consists of: +# +# - get post put patch delete options +# - WebSocket(ws) +# - before_* +# - error +HTTP_METHODS = %w(get post put patch delete options) +FILTER_METHODS = %w(get post put patch delete options all) + +{% for method in HTTP_METHODS %} + def {{method.id}}(path : String, &block : HTTP::Server::Context -> _) + raise Kemal::Exceptions::InvalidPathStartException.new({{method}}, path) unless Kemal::Utils.path_starts_with_slash?(path) + Kemal::RouteHandler::INSTANCE.add_route({{method}}.upcase, path, &block) + end +{% end %} + +def ws(path : String, &block : HTTP::WebSocket, HTTP::Server::Context -> Void) + raise Kemal::Exceptions::InvalidPathStartException.new("ws", path) unless Kemal::Utils.path_starts_with_slash?(path) + Kemal::WebSocketHandler::INSTANCE.add_route path, &block +end + +def error(status_code : Int32, &block : HTTP::Server::Context, Exception -> _) + Kemal.config.add_error_handler status_code, &block +end + +# All the helper methods available are: +# - before_all, before_get, before_post, before_put, before_patch, before_delete, before_options +# - after_all, after_get, after_post, after_put, after_patch, after_delete, after_options +{% for type in ["before", "after"] %} + {% for method in FILTER_METHODS %} + def {{type.id}}_{{method.id}}(path : String = "*", &block : HTTP::Server::Context -> _) + Kemal::FilterHandler::INSTANCE.{{type.id}}({{method}}.upcase, path, &block) + end + {% end %} +{% end %} diff --git a/lib/kemal/src/kemal/exception_handler.cr b/lib/kemal/src/kemal/exception_handler.cr new file mode 100644 index 00000000..eee6eecd --- /dev/null +++ b/lib/kemal/src/kemal/exception_handler.cr @@ -0,0 +1,30 @@ +module Kemal + # Handles all the exceptions, including 404, custom errors and 500. + class ExceptionHandler + include HTTP::Handler + INSTANCE = new + + def call(context : HTTP::Server::Context) + call_next(context) + rescue ex : Kemal::Exceptions::RouteNotFound + call_exception_with_status_code(context, ex, 404) + rescue ex : Kemal::Exceptions::CustomException + call_exception_with_status_code(context, ex, context.response.status_code) + rescue ex : Exception + log("Exception: #{ex.inspect_with_backtrace}") + return call_exception_with_status_code(context, ex, 500) if Kemal.config.error_handlers.has_key?(500) + verbosity = Kemal.config.env == "production" ? false : true + render_500(context, ex, verbosity) + end + + private def call_exception_with_status_code(context : HTTP::Server::Context, exception : Exception, status_code : Int32) + return if context.response.closed? + if !Kemal.config.error_handlers.empty? && Kemal.config.error_handlers.has_key?(status_code) + context.response.content_type = "text/html" unless context.response.headers.has_key?("Content-Type") + context.response.status_code = status_code + context.response.print Kemal.config.error_handlers[status_code].call(context, exception) + context + end + end + end +end diff --git a/lib/kemal/src/kemal/ext/context.cr b/lib/kemal/src/kemal/ext/context.cr new file mode 100644 index 00000000..f9a12caf --- /dev/null +++ b/lib/kemal/src/kemal/ext/context.cr @@ -0,0 +1,61 @@ +# `HTTP::Server::Context` is the class which holds `HTTP::Request` and +# `HTTP::Server::Response` alongside with information such as request params, +# request/response content_type, session data and alike. +# +# Instances of this class are passed to an `HTTP::Server` handler. +class HTTP::Server + class Context + # :nodoc: + STORE_MAPPINGS = [Nil, String, Int32, Int64, Float64, Bool] + + macro finished + alias StoreTypes = Union({{ *STORE_MAPPINGS }}) + @store = {} of String => StoreTypes + end + + def params + @params ||= Kemal::ParamParser.new(@request, route_lookup.params) + end + + def redirect(url : String, status_code : Int32 = 302) + @response.headers.add "Location", url + @response.status_code = status_code + end + + def route + route_lookup.payload + end + + def websocket + ws_route_lookup.payload + end + + def route_lookup + Kemal::RouteHandler::INSTANCE.lookup_route(@request.method.as(String), @request.path) + end + + def route_found? + route_lookup.found? + end + + def ws_route_lookup + Kemal::WebSocketHandler::INSTANCE.lookup_ws_route(@request.path) + end + + def ws_route_found? + ws_route_lookup.found? + end + + def get(name : String) + @store[name] + end + + def set(name : String, value : StoreTypes) + @store[name] = value + end + + def get?(name : String) + @store[name]? + end + end +end diff --git a/lib/kemal/src/kemal/ext/response.cr b/lib/kemal/src/kemal/ext/response.cr new file mode 100644 index 00000000..1f0a8fdb --- /dev/null +++ b/lib/kemal/src/kemal/ext/response.cr @@ -0,0 +1,13 @@ +class HTTP::Server::Response + class Output + def close + unless response.wrote_headers? && !response.headers.has_key?("Content-Range") + response.content_length = @out_count + end + + ensure_headers_written + + super + end + end +end diff --git a/lib/kemal/src/kemal/file_upload.cr b/lib/kemal/src/kemal/file_upload.cr new file mode 100644 index 00000000..30eb26aa --- /dev/null +++ b/lib/kemal/src/kemal/file_upload.cr @@ -0,0 +1,24 @@ +module Kemal + struct FileUpload + getter tempfile : File + getter filename : String? + getter headers : HTTP::Headers + getter creation_time : Time? + getter modification_time : Time? + getter read_time : Time? + getter size : UInt64? + + def initialize(upload) + @tempfile = File.tempfile + ::File.open(@tempfile.path, "w") do |file| + IO.copy(upload.body, file) + end + @filename = upload.filename + @headers = upload.headers + @creation_time = upload.creation_time + @modification_time = upload.modification_time + @read_time = upload.read_time + @size = upload.size + end + end +end diff --git a/lib/kemal/src/kemal/filter_handler.cr b/lib/kemal/src/kemal/filter_handler.cr new file mode 100644 index 00000000..6d28680a --- /dev/null +++ b/lib/kemal/src/kemal/filter_handler.cr @@ -0,0 +1,90 @@ +module Kemal + # :nodoc: + class FilterHandler + include HTTP::Handler + INSTANCE = new + + # This middleware is lazily instantiated and added to the handlers as soon as a call to `after_X` or `before_X` is made. + def initialize + @tree = Radix::Tree(Array(FilterBlock)).new + Kemal.config.add_filter_handler(self) + end + + # The call order of the filters is `before_all -> before_x -> X -> after_x -> after_all`. + def call(context : HTTP::Server::Context) + return call_next(context) unless context.route_found? + call_block_for_path_type("ALL", context.request.path, :before, context) + call_block_for_path_type(context.request.method, context.request.path, :before, context) + if Kemal.config.error_handlers.has_key?(context.response.status_code) + raise Kemal::Exceptions::CustomException.new(context) + end + call_next(context) + call_block_for_path_type(context.request.method, context.request.path, :after, context) + call_block_for_path_type("ALL", context.request.path, :after, context) + context + end + + # :nodoc: This shouldn't be called directly, it's not private because + # I need to call it for testing purpose since I can't call the macros in the spec. + # It adds the block for the corresponding verb/path/type combination to the tree. + def _add_route_filter(verb : String, path, type, &block : HTTP::Server::Context -> _) + lookup = lookup_filters_for_path_type(verb, path, type) + if lookup.found? && lookup.payload.is_a?(Array(FilterBlock)) + lookup.payload << FilterBlock.new(&block) + else + @tree.add radix_path(verb, path, type), [FilterBlock.new(&block)] + end + end + + # This can be called directly but it's simpler to just use the macros, + # it will check if another filter is not already defined for this + # verb/path/type and proceed to call `add_route_filter` + def before(verb : String, path : String = "*", &block : HTTP::Server::Context -> _) + _add_route_filter verb, path, :before, &block + end + + # This can be called directly but it's simpler to just use the macros, + # it will check if another filter is not already defined for this + # verb/path/type and proceed to call `add_route_filter` + def after(verb : String, path : String = "*", &block : HTTP::Server::Context -> _) + _add_route_filter verb, path, :after, &block + end + + # This will fetch the block for the verb/path/type from the tree and call it. + private def call_block_for_path_type(verb : String?, path : String, type, context : HTTP::Server::Context) + lookup = lookup_filters_for_path_type(verb, path, type) + if lookup.found? && lookup.payload.is_a? Array(FilterBlock) + blocks = lookup.payload + blocks.each &.call(context) + end + end + + # This checks is filter is already defined for the verb/path/type combination + private def filter_for_path_type_defined?(verb : String, path : String, type) + lookup = @tree.find radix_path(verb, path, type) + lookup.found? && lookup.payload.is_a? FilterBlock + end + + # This returns a lookup for verb/path/type + private def lookup_filters_for_path_type(verb : String?, path : String, type) + @tree.find radix_path(verb, path, type) + end + + private def radix_path(verb : String?, path : String, type : Symbol) + "#{type}/#{verb}/#{path}" + end + + # :nodoc: + class FilterBlock + property block : HTTP::Server::Context -> String + + def initialize(&block : HTTP::Server::Context -> _) + @block = ->(context : HTTP::Server::Context) { block.call(context).to_s } + end + + def call(context : HTTP::Server::Context) + @block.call(context) + end + end + end +end diff --git a/lib/kemal/src/kemal/handler.cr b/lib/kemal/src/kemal/handler.cr new file mode 100644 index 00000000..bb592381 --- /dev/null +++ b/lib/kemal/src/kemal/handler.cr @@ -0,0 +1,80 @@ +module Kemal + # `Kemal::Handler` is a subclass of `HTTP::Handler`. + # + # It adds `only`, `only_match?`, `exclude`, `exclude_match?`. + # These methods are useful for the conditional execution of custom handlers . + class Handler + include HTTP::Handler + + @@only_routes_tree = Radix::Tree(String).new + @@exclude_routes_tree = Radix::Tree(String).new + + macro only(paths, method = "GET") + class_name = {{@type.name}} + method_downcase = {{method.downcase}} + class_name_method = "#{class_name}/#{method_downcase}" + ({{paths}}).each do |path| + @@only_routes_tree.add class_name_method + path, '/' + method_downcase + path + end + end + + macro exclude(paths, method = "GET") + class_name = {{@type.name}} + method_downcase = {{method.downcase}} + class_name_method = "#{class_name}/#{method_downcase}" + ({{paths}}).each do |path| + @@exclude_routes_tree.add class_name_method + path, '/' + method_downcase + path + end + end + + def call(env : HTTP::Server::Context) + call_next(env) + end + + # Processes the path based on `only` paths which is a `Array(String)`. + # If the path is not found on `only` conditions the handler will continue processing. + # If the path is found in `only` conditions it'll stop processing and will pass the request + # to next handler. + # + # However this is not done automatically. All handlers must inherit from `Kemal::Handler`. + # + # ``` + # class OnlyHandler < Kemal::Handler + # only ["/"] + # + # def call(env) + # return call_next(env) unless only_match?(env) + # puts "If the path is / i will be doing some processing here." + # end + # end + # ``` + def only_match?(env : HTTP::Server::Context) + @@only_routes_tree.find(radix_path(env.request.method, env.request.path)).found? + end + + # Processes the path based on `exclude` paths which is a `Array(String)`. + # If the path is not found on `exclude` conditions the handler will continue processing. + # If the path is found in `exclude` conditions it'll stop processing and will pass the request + # to next handler. + # + # However this is not done automatically. All handlers must inherit from `Kemal::Handler`. + # + # ``` + # class ExcludeHandler < Kemal::Handler + # exclude ["/"] + # + # def call(env) + # return call_next(env) if exclude_match?(env) + # puts "If the path is not / i will be doing some processing here." + # end + # end + # ``` + def exclude_match?(env : HTTP::Server::Context) + @@exclude_routes_tree.find(radix_path(env.request.method, env.request.path)).found? + end + + private def radix_path(method : String, path : String) + "#{self.class}/#{method.downcase}#{path}" + end + end +end diff --git a/lib/kemal/src/kemal/helpers/exception_page.cr b/lib/kemal/src/kemal/helpers/exception_page.cr new file mode 100644 index 00000000..4ec180c9 --- /dev/null +++ b/lib/kemal/src/kemal/helpers/exception_page.cr @@ -0,0 +1,38 @@ +require "exception_page" + +module Kemal + class ExceptionPage < ExceptionPage + def styles : ExceptionPage::Styles + ExceptionPage::Styles.new( + accent: "red", + logo_uri: "" + ) + end + + def project_url + "https://kemalcr.com/" + end + + def self.for_production_exception + <<-HTML + + + + + + +

Kemal has encountered an error. (500)

+

Something wrong with the server :(

+ + + HTML + end + end +end diff --git a/lib/kemal/src/kemal/helpers/exceptions.cr b/lib/kemal/src/kemal/helpers/exceptions.cr new file mode 100644 index 00000000..53520b3a --- /dev/null +++ b/lib/kemal/src/kemal/helpers/exceptions.cr @@ -0,0 +1,20 @@ +# Exceptions for 404 and custom errors are defined here. +module Kemal::Exceptions + class InvalidPathStartException < Exception + def initialize(method : String, path : String) + super "Route declaration #{method} \"#{path}\" needs to start with '/', should be #{method} \"/#{path}\"" + end + end + + class RouteNotFound < Exception + def initialize(context : HTTP::Server::Context) + super "Requested path: '#{context.request.method}:#{context.request.path}' was not found." + end + end + + class CustomException < Exception + def initialize(context : HTTP::Server::Context) + super "Rendered error with #{context.response.status_code}" + end + end +end diff --git a/lib/kemal/src/kemal/helpers/helpers.cr b/lib/kemal/src/kemal/helpers/helpers.cr new file mode 100644 index 00000000..7baf8510 --- /dev/null +++ b/lib/kemal/src/kemal/helpers/helpers.cr @@ -0,0 +1,254 @@ +require "mime" + +# Adds given `Kemal::Handler` to handlers chain. +# There are 5 handlers by default and all the custom handlers +# goes between the first 4 and the last `Kemal::RouteHandler`. +# +# - `Kemal::InitHandler` +# - `Kemal::LogHandler` +# - `Kemal::ExceptionHandler` +# - `Kemal::StaticFileHandler` +# - Here goes custom handlers +# - `Kemal::RouteHandler` +def add_handler(handler : HTTP::Handler) + Kemal.config.add_handler handler +end + +def add_handler(handler : HTTP::Handler, position : Int32) + Kemal.config.add_handler handler, position +end + +# Sets public folder from which the static assets will be served. +# +# By default this is `/public` not `src/public`. +def public_folder(path : String) + Kemal.config.public_folder = path +end + +# Logs the output via `logger`. +# This is the built-in `Kemal::LogHandler` by default which uses STDOUT. +def log(message : String) + Kemal.config.logger.write "#{message}\n" +end + +# Enables / Disables logging. +# This is enabled by default. +# +# ``` +# logging false +# ``` +def logging(status : Bool) + Kemal.config.logging = status +end + +# This is used to replace the built-in `Kemal::LogHandler` with a custom logger. +# +# A custom logger must inherit from `Kemal::BaseLogHandler` and must implement +# `call(env)`, `write(message)` methods. +# +# ``` +# class MyCustomLogger < Kemal::BaseLogHandler +# def call(env) +# puts "I'm logging some custom stuff here." +# call_next(env) # => This calls the next handler +# end +# +# # This is used from `log` method. +# def write(message) +# STDERR.puts message # => Logs the output to STDERR +# end +# end +# ``` +# +# Now that we have a custom logger here's how we use it +# +# ``` +# logger MyCustomLogger.new +# ``` +def logger(logger : Kemal::BaseLogHandler) + Kemal.config.logger = logger + Kemal.config.add_handler logger +end + +# Enables / Disables static file serving. +# This is enabled by default. +# +# ``` +# serve_static false +# ``` +# +# Static server also have some advanced customization options like `dir_listing` and +# `gzip`. +# +# ``` +# serve_static {"gzip" => true, "dir_listing" => false} +# ``` +def serve_static(status : (Bool | Hash)) + Kemal.config.serve_static = status +end + +# Helper for easily modifying response headers. +# This can be used to modify a response header with the given hash. +# +# ``` +# def call(env) +# headers(env, {"X-Custom-Header" => "This is a custom value"}) +# end +# ``` +def headers(env : HTTP::Server::Context, additional_headers : Hash(String, String)) + env.response.headers.merge!(additional_headers) +end + +# Send a file with given path and base the mime-type on the file extension +# or default `application/octet-stream` mime_type. +# +# ``` +# send_file env, "./path/to/file" +# ``` +# +# Optionally you can override the mime_type +# +# ``` +# send_file env, "./path/to/file", "image/jpeg" +# ``` +# +# Also you can set the filename and the disposition +# +# ``` +# send_file env, "./path/to/file", filename: "image.jpg", disposition: "attachment" +# ``` +def send_file(env : HTTP::Server::Context, path : String, mime_type : String? = nil, *, filename : String? = nil, disposition : String? = nil) + config = Kemal.config.serve_static + file_path = File.expand_path(path, Dir.current) + mime_type ||= MIME.from_filename(file_path, "application/octet-stream") + env.response.content_type = mime_type + env.response.headers["Accept-Ranges"] = "bytes" + env.response.headers["X-Content-Type-Options"] = "nosniff" + minsize = 860 # http://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits ?? + request_headers = env.request.headers + filesize = File.size(file_path) + filestat = File.info(file_path) + attachment(env, filename, disposition) + + Kemal.config.static_headers.try(&.call(env.response, file_path, filestat)) + + File.open(file_path) do |file| + if env.request.method == "GET" && env.request.headers.has_key?("Range") + next multipart(file, env) + end + + condition = config.is_a?(Hash) && config["gzip"]? == true && filesize > minsize && Kemal::Utils.zip_types(file_path) + if condition && request_headers.includes_word?("Accept-Encoding", "gzip") + env.response.headers["Content-Encoding"] = "gzip" + Gzip::Writer.open(env.response) do |deflate| + IO.copy(file, deflate) + end + elsif condition && request_headers.includes_word?("Accept-Encoding", "deflate") + env.response.headers["Content-Encoding"] = "deflate" + Flate::Writer.open(env.response) do |deflate| + IO.copy(file, deflate) + end + else + env.response.content_length = filesize + IO.copy(file, env.response) + end + end + return +end + +# Send a file with given data and default `application/octet-stream` mime_type. +# +# ``` +# send_file env, data_slice +# ``` +# +# Optionally you can override the mime_type +# +# ``` +# send_file env, data_slice, "image/jpeg" +# ``` +# +# Also you can set the filename and the disposition +# +# ``` +# send_file env, data_slice, filename: "image.jpg", disposition: "attachment" +# ``` +def send_file(env : HTTP::Server::Context, data : Slice(UInt8), mime_type : String? = nil, *, filename : String? = nil, disposition : String? = nil) + mime_type ||= "application/octet-stream" + env.response.content_type = mime_type + env.response.content_length = data.bytesize + attachment(env, filename, disposition) + env.response.write data +end + +private def multipart(file, env : HTTP::Server::Context) + # See http://httpwg.org/specs/rfc7233.html + fileb = file.size + startb = endb = 0_i64 + + if match = env.request.headers["Range"].match /bytes=(\d{1,})-(\d{0,})/ + startb = match[1].to_i64 { 0_i64 } if match.size >= 2 + endb = match[2].to_i64 { 0_i64 } if match.size >= 3 + end + + endb = fileb - 1 if endb == 0 + + if startb < endb < fileb + content_length = 1_i64 + endb - startb + env.response.status_code = 206 + env.response.content_length = content_length + env.response.headers["Accept-Ranges"] = "bytes" + env.response.headers["Content-Range"] = "bytes #{startb}-#{endb}/#{fileb}" # MUST + + if startb > 1024 + skipped = 0_i64 + # file.skip only accepts values less or equal to 1024 (buffer size, undocumented) + until (increase_skipped = skipped + 1024_i64) > startb + file.skip(1024) + skipped = increase_skipped + end + if (skipped_minus_startb = skipped - startb) > 0 + file.skip skipped_minus_startb + end + else + file.skip(startb) + end + + IO.copy(file, env.response, content_length) + else + env.response.content_length = fileb + env.response.status_code = 200 # Range not satisfable, see 4.4 Note + IO.copy(file, env.response) + end +end + +# Set the Content-Disposition to "attachment" with the specified filename, +# instructing the user agents to prompt to save. +private def attachment(env : HTTP::Server::Context, filename : String? = nil, disposition : String? = nil) + disposition = "attachment" if disposition.nil? && filename + if disposition && filename + env.response.headers["Content-Disposition"] = "#{disposition}; filename=\"#{File.basename(filename)}\"" + end +end + +# Configures an `HTTP::Server::Response` to compress the response +# output, either using gzip or deflate, depending on the `Accept-Encoding` request header. +# +# Disabled by default. +def gzip(status : Bool = false) + add_handler HTTP::CompressHandler.new if status +end + +# Adds headers to `Kemal::StaticFileHandler`. This is especially useful for `CORS`. +# +# ``` +# static_headers do |response, filepath, filestat| +# if filepath =~ /\.html$/ +# response.headers.add("Access-Control-Allow-Origin", "*") +# end +# response.headers.add("Content-Size", filestat.size.to_s) +# end +# ``` +def static_headers(&headers : HTTP::Server::Response, String, File::Info -> Void) + Kemal.config.static_headers = headers +end diff --git a/lib/kemal/src/kemal/helpers/macros.cr b/lib/kemal/src/kemal/helpers/macros.cr new file mode 100644 index 00000000..4b5e3090 --- /dev/null +++ b/lib/kemal/src/kemal/helpers/macros.cr @@ -0,0 +1,98 @@ +require "kilt" + +CONTENT_FOR_BLOCKS = Hash(String, Tuple(String, Proc(String))).new + +# `content_for` is a set of helpers that allows you to capture +# blocks inside views to be rendered later during the request. The most +# common use is to populate different parts of your layout from your view. +# +# The currently supported engines are: ecr and slang. +# +# ## Usage +# +# You call `content_for`, generally from a view, to capture a block of markup +# giving it an identifier: +# +# ``` +# # index.ecr +# <% content_for "some_key" do %> +# ... +# <% end %> +# ``` +# +# Then, you call `yield_content` with that identifier, generally from a +# layout, to render the captured block: +# +# ``` +# # layout.ecr +# <%= yield_content "some_key" %> +# ``` +# +# ## And How Is This Useful? +# +# For example, some of your views might need a few javascript tags and +# stylesheets, but you don't want to force this files in all your pages. +# Then you can put `<%= yield_content :scripts_and_styles %>` on your +# layout, inside the tag, and each view can call `content_for` +# setting the appropriate set of tags that should be added to the layout. +macro content_for(key, file = __FILE__) + %proc = ->() { + __kilt_io__ = IO::Memory.new + {{ yield }} + __kilt_io__.to_s + } + + CONTENT_FOR_BLOCKS[{{key}}] = Tuple.new {{file}}, %proc + nil +end + +# Yields content for the given key if a `content_for` block exists for that key. +macro yield_content(key) + if CONTENT_FOR_BLOCKS.has_key?({{key}}) + __caller_filename__ = CONTENT_FOR_BLOCKS[{{key}}][0] + %proc = CONTENT_FOR_BLOCKS[{{key}}][1] + %proc.call if __content_filename__ == __caller_filename__ + end +end + +# Render view with a layout as the superview. +# +# ``` +# render "src/views/index.ecr", "src/views/layout.ecr" +# ``` +macro render(filename, layout) + __content_filename__ = {{filename}} + content = render {{filename}} + render {{layout}} +end + +# Render view with the given filename. +macro render(filename) + Kilt.render({{filename}}) +end + +# Halt execution with the current context. +# Returns 200 and an empty response by default. +# +# ``` +# halt env, status_code: 403, response: "Forbidden" +# ``` +macro halt(env, status_code = 200, response = "") + {{env}}.response.status_code = {{status_code}} + {{env}}.response.print {{response}} + {{env}}.response.close + next +end + +# Extends context storage with user defined types. +# +# ``` +# class User +# property name +# end +# +# add_context_storage_type(User) +# ``` +macro add_context_storage_type(type) + {{ HTTP::Server::Context::STORE_MAPPINGS.push(type) }} +end diff --git a/lib/kemal/src/kemal/helpers/templates.cr b/lib/kemal/src/kemal/helpers/templates.cr new file mode 100644 index 00000000..b343fc8a --- /dev/null +++ b/lib/kemal/src/kemal/helpers/templates.cr @@ -0,0 +1,35 @@ +# This file contains the built-in view templates that Kemal uses. +# Currently it contains templates for 404 and 500 error codes. + +def render_404 + <<-HTML + + + + + + +

Kemal doesn't know this way.

+ + + + HTML +end + +def render_500(context, exception, verbosity) + context.response.status_code = 500 + + template = if verbosity + Kemal::ExceptionPage.for_runtime_exception(context, exception).to_s + else + Kemal::ExceptionPage.for_production_exception + end + + context.response.print template + context +end diff --git a/lib/kemal/src/kemal/helpers/utils.cr b/lib/kemal/src/kemal/helpers/utils.cr new file mode 100644 index 00000000..3ece5d44 --- /dev/null +++ b/lib/kemal/src/kemal/helpers/utils.cr @@ -0,0 +1,13 @@ +module Kemal + module Utils + ZIP_TYPES = {".htm", ".html", ".txt", ".css", ".js", ".svg", ".json", ".xml", ".otf", ".ttf", ".woff", ".woff2"} + + def self.path_starts_with_slash?(path : String) + path.starts_with? '/' + end + + def self.zip_types(path : String) # https://github.com/h5bp/server-configs-nginx/blob/master/nginx.conf + ZIP_TYPES.includes? File.extname(path) + end + end +end diff --git a/lib/kemal/src/kemal/init_handler.cr b/lib/kemal/src/kemal/init_handler.cr new file mode 100644 index 00000000..881325b6 --- /dev/null +++ b/lib/kemal/src/kemal/init_handler.cr @@ -0,0 +1,15 @@ +module Kemal + # Initializes the context with default values, such as + # *Content-Type* or *X-Powered-By* headers. + class InitHandler + include HTTP::Handler + + INSTANCE = new + + def call(context : HTTP::Server::Context) + context.response.headers.add "X-Powered-By", "Kemal" if Kemal.config.powered_by_header + context.response.content_type = "text/html" unless context.response.headers.has_key?("Content-Type") + call_next context + end + end +end diff --git a/lib/kemal/src/kemal/log_handler.cr b/lib/kemal/src/kemal/log_handler.cr new file mode 100644 index 00000000..fe902f28 --- /dev/null +++ b/lib/kemal/src/kemal/log_handler.cr @@ -0,0 +1,25 @@ +module Kemal + # Uses `STDOUT` by default and handles the logging of request/response process time. + class LogHandler < Kemal::BaseLogHandler + def initialize(@io : IO = STDOUT) + end + + def call(context : HTTP::Server::Context) + elapsed_time = Time.measure { call_next(context) } + elapsed_text = elapsed_text(elapsed_time) + @io << Time.utc << ' ' << context.response.status_code << ' ' << context.request.method << ' ' << context.request.resource << ' ' << elapsed_text << '\n' + context + end + + def write(message : String) + @io << message + end + + private def elapsed_text(elapsed) + millis = elapsed.total_milliseconds + return "#{millis.round(2)}ms" if millis >= 1 + + "#{(millis * 1000).round(2)}µs" + end + end +end diff --git a/lib/kemal/src/kemal/null_log_handler.cr b/lib/kemal/src/kemal/null_log_handler.cr new file mode 100644 index 00000000..9f3e03ab --- /dev/null +++ b/lib/kemal/src/kemal/null_log_handler.cr @@ -0,0 +1,11 @@ +module Kemal + # This is here to represent the logger corresponding to Null Object Pattern. + class NullLogHandler < Kemal::BaseLogHandler + def call(context : HTTP::Server::Context) + call_next(context) + end + + def write(message : String) + end + end +end diff --git a/lib/kemal/src/kemal/param_parser.cr b/lib/kemal/src/kemal/param_parser.cr new file mode 100644 index 00000000..92c31123 --- /dev/null +++ b/lib/kemal/src/kemal/param_parser.cr @@ -0,0 +1,111 @@ +module Kemal + # Parses the request contents including query_params and body + # and converts them into a params hash which you can use within + # the environment context. + class ParamParser + URL_ENCODED_FORM = "application/x-www-form-urlencoded" + APPLICATION_JSON = "application/json" + MULTIPART_FORM = "multipart/form-data" + PARTS = %w(url query body json files) + # :nodoc: + alias AllParamTypes = Nil | String | Int64 | Float64 | Bool | Hash(String, JSON::Any) | Array(JSON::Any) + getter files + + def initialize(@request : HTTP::Request, @url : Hash(String, String) = {} of String => String) + @query = HTTP::Params.new({} of String => Array(String)) + @body = HTTP::Params.new({} of String => Array(String)) + @json = {} of String => AllParamTypes + @files = {} of String => FileUpload + @url_parsed = false + @query_parsed = false + @body_parsed = false + @json_parsed = false + @files_parsed = false + end + + private def unescape_url_param(value : String) + value.empty? ? value : URI.decode(value) + rescue + value + end + + {% for method in PARTS %} + def {{method.id}} + # check memoization + return @{{method.id}} if @{{method.id}}_parsed + + parse_{{method.id}} + # memoize + @{{method.id}}_parsed = true + @{{method.id}} + end + {% end %} + + private def parse_body + content_type = @request.headers["Content-Type"]? + + return unless content_type + + if content_type.try(&.starts_with?(URL_ENCODED_FORM)) + @body = parse_part(@request.body) + return + end + + if content_type.try(&.starts_with?(MULTIPART_FORM)) + parse_files + end + end + + private def parse_query + @query = parse_part(@request.query) + end + + private def parse_url + @url.each { |key, value| @url[key] = unescape_url_param(value) } + end + + private def parse_files + return if @files_parsed + + HTTP::FormData.parse(@request) do |upload| + next unless upload + + filename = upload.filename + + if !filename.nil? + @files[upload.name] = FileUpload.new(upload) + else + @body.add(upload.name, upload.body.gets_to_end) + end + end + + @files_parsed = true + end + + # Parses JSON request body if Content-Type is `application/json`. + # + # - If request body is a JSON `Hash` then all the params are parsed and added into `params`. + # - If request body is a JSON `Array` it's added into `params` as `_json` and can be accessed like `params["_json"]`. + private def parse_json + return unless @request.body && @request.headers["Content-Type"]?.try(&.starts_with?(APPLICATION_JSON)) + + body = @request.body.not_nil!.gets_to_end + case json = JSON.parse(body).raw + when Hash + json.each do |key, value| + @json[key] = value.raw + end + when Array + @json["_json"] = json + end + end + + private def parse_part(part : IO?) + HTTP::Params.parse(part ? part.gets_to_end : "") + end + + private def parse_part(part : String?) + HTTP::Params.parse part.to_s + end + end +end diff --git a/lib/kemal/src/kemal/route.cr b/lib/kemal/src/kemal/route.cr new file mode 100644 index 00000000..cfcf29ec --- /dev/null +++ b/lib/kemal/src/kemal/route.cr @@ -0,0 +1,17 @@ +module Kemal + # Route is the main building block of Kemal. + # + # It takes 3 parameters: http *method*, *path* and a *handler* to specify + # what action to be done if the route is matched. + struct Route + getter method, path, handler + @handler : HTTP::Server::Context -> String + + def initialize(@method : String, @path : String, &handler : HTTP::Server::Context -> _) + @handler = ->(context : HTTP::Server::Context) do + output = handler.call(context) + output.is_a?(String) ? output : "" + end + end + end +end diff --git a/lib/kemal/src/kemal/route_handler.cr b/lib/kemal/src/kemal/route_handler.cr new file mode 100644 index 00000000..528d7736 --- /dev/null +++ b/lib/kemal/src/kemal/route_handler.cr @@ -0,0 +1,67 @@ +require "radix" + +module Kemal + class RouteHandler + include HTTP::Handler + + INSTANCE = new + CACHED_ROUTES_LIMIT = 1024 + property routes, cached_routes + + def initialize + @routes = Radix::Tree(Route).new + @cached_routes = Hash(String, Radix::Result(Route)).new + end + + def call(context : HTTP::Server::Context) + process_request(context) + end + + # Adds a given route to routing tree. As an exception each `GET` route additionaly defines + # a corresponding `HEAD` route. + def add_route(method : String, path : String, &handler : HTTP::Server::Context -> _) + add_to_radix_tree method, path, Route.new(method, path, &handler) + add_to_radix_tree("HEAD", path, Route.new("HEAD", path) { }) if method == "GET" + end + + # Looks up the route from the Radix::Tree for the first time and caches to improve performance. + def lookup_route(verb : String, path : String) + lookup_path = radix_path(verb, path) + + if cached_route = @cached_routes[lookup_path]? + return cached_route + end + + route = @routes.find(lookup_path) + + if route.found? + @cached_routes.clear if @cached_routes.size == CACHED_ROUTES_LIMIT + @cached_routes[lookup_path] = route + end + + route + end + + # Processes the route if it's a match. Otherwise renders 404. + private def process_request(context) + raise Kemal::Exceptions::RouteNotFound.new(context) unless context.route_found? + content = context.route.handler.call(context) + + if !Kemal.config.error_handlers.empty? && Kemal.config.error_handlers.has_key?(context.response.status_code) + raise Kemal::Exceptions::CustomException.new(context) + end + + context.response.print(content) + context + end + + private def radix_path(method, path) + '/' + method.downcase + path + end + + private def add_to_radix_tree(method, path, route) + node = radix_path method, path + @routes.add node, route + end + end +end diff --git a/lib/kemal/src/kemal/ssl.cr b/lib/kemal/src/kemal/ssl.cr new file mode 100644 index 00000000..e205b18b --- /dev/null +++ b/lib/kemal/src/kemal/ssl.cr @@ -0,0 +1,17 @@ +module Kemal + class SSL + getter context + + def initialize + @context = OpenSSL::SSL::Context::Server.new + end + + def key_file=(key_file : String) + @context.private_key = key_file + end + + def cert_file=(cert_file : String) + @context.certificate_chain = cert_file + end + end +end diff --git a/lib/kemal/src/kemal/static_file_handler.cr b/lib/kemal/src/kemal/static_file_handler.cr new file mode 100644 index 00000000..cd837496 --- /dev/null +++ b/lib/kemal/src/kemal/static_file_handler.cr @@ -0,0 +1,71 @@ +{% if !flag?(:without_zlib) %} + require "zlib" +{% end %} + +module Kemal + class StaticFileHandler < HTTP::StaticFileHandler + def call(context : HTTP::Server::Context) + return call_next(context) if context.request.path.not_nil! == "/" + + case context.request.method + when "GET", "HEAD" + else + if @fallthrough + call_next(context) + else + context.response.status_code = 405 + context.response.headers.add("Allow", "GET, HEAD") + end + return + end + + config = Kemal.config.serve_static + original_path = context.request.path.not_nil! + request_path = URI.decode(original_path) + + # File path cannot contains '\0' (NUL) because all filesystem I know + # don't accept '\0' character as file name. + if request_path.includes? '\0' + context.response.status_code = 400 + return + end + + expanded_path = File.expand_path(request_path, "/") + is_dir_path = if original_path.ends_with?('/') && !expanded_path.ends_with? '/' + expanded_path = expanded_path + '/' + true + else + expanded_path.ends_with? '/' + end + + file_path = File.join(@public_dir, expanded_path) + is_dir = Dir.exists? file_path + + if request_path != expanded_path + redirect_to context, expanded_path + elsif is_dir && !is_dir_path + redirect_to context, expanded_path + '/' + end + + if Dir.exists?(file_path) + if config.is_a?(Hash) && config["dir_listing"] == true + context.response.content_type = "text/html" + directory_listing(context.response, request_path, file_path) + else + call_next(context) + end + elsif File.exists?(file_path) + last_modified = modification_time(file_path) + add_cache_headers(context.response.headers, last_modified) + + if cache_request?(context, last_modified) + context.response.status_code = 304 + return + end + send_file(context, file_path) + else + call_next(context) + end + end + end +end diff --git a/lib/kemal/src/kemal/websocket.cr b/lib/kemal/src/kemal/websocket.cr new file mode 100644 index 00000000..2b65f8ec --- /dev/null +++ b/lib/kemal/src/kemal/websocket.cr @@ -0,0 +1,14 @@ +module Kemal + # Takes 2 parameters: *path* and a *handler* to specify + # what action to be done if the route is matched. + class WebSocket < HTTP::WebSocketHandler + getter proc + + def initialize(@path : String, &@proc : HTTP::WebSocket, HTTP::Server::Context -> Void) + end + + def call(context : HTTP::Server::Context) + super + end + end +end diff --git a/lib/kemal/src/kemal/websocket_handler.cr b/lib/kemal/src/kemal/websocket_handler.cr new file mode 100644 index 00000000..addbecfa --- /dev/null +++ b/lib/kemal/src/kemal/websocket_handler.cr @@ -0,0 +1,43 @@ +module Kemal + class WebSocketHandler + include HTTP::Handler + + INSTANCE = new + property routes + + def initialize + @routes = Radix::Tree(WebSocket).new + end + + def call(context : HTTP::Server::Context) + return call_next(context) unless context.ws_route_found? && websocket_upgrade_request?(context) + content = context.websocket.call(context) + context.response.print(content) + context + end + + def lookup_ws_route(path : String) + @routes.find "/ws" + path + end + + def add_route(path : String, &handler : HTTP::WebSocket, HTTP::Server::Context -> Void) + add_to_radix_tree path, WebSocket.new(path, &handler) + end + + private def add_to_radix_tree(path, websocket) + node = radix_path "ws", path + @routes.add node, websocket + end + + private def radix_path(method, path) + '/' + method.downcase + path + end + + private def websocket_upgrade_request?(context) + return unless upgrade = context.request.headers["Upgrade"]? + return unless upgrade.compare("websocket", case_insensitive: true) == 0 + + context.request.headers.includes_word?("Connection", "Upgrade") + end + end +end diff --git a/lib/kilt/.gitignore b/lib/kilt/.gitignore new file mode 100644 index 00000000..2e5d65e9 --- /dev/null +++ b/lib/kilt/.gitignore @@ -0,0 +1,10 @@ +/doc/ +/lib/ +/.crystal/ +/.shards/ + + +# Libraries don't need dependency lock +# Dependencies will be locked in application that uses them +/shard.lock + diff --git a/lib/kilt/.travis.yml b/lib/kilt/.travis.yml new file mode 100644 index 00000000..ffc7b6ac --- /dev/null +++ b/lib/kilt/.travis.yml @@ -0,0 +1 @@ +language: crystal diff --git a/lib/kilt/LICENSE b/lib/kilt/LICENSE new file mode 100644 index 00000000..f7175eff --- /dev/null +++ b/lib/kilt/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 Jerome Gravel-Niquet + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/lib/kilt/README.md b/lib/kilt/README.md new file mode 100644 index 00000000..43272988 --- /dev/null +++ b/lib/kilt/README.md @@ -0,0 +1,105 @@ +# Kilt [![Build Status](https://travis-ci.org/jeromegn/kilt.svg?branch=master)](https://travis-ci.org/jeromegn/kilt) [![Dependency Status](https://shards.rocks/badge/github/jeromegn/kilt/status.svg)](https://shards.rocks/github/jeromegn/kilt) [![devDependency Status](https://shards.rocks/badge/github/jeromegn/kilt/dev_status.svg)](https://shards.rocks/github/jeromegn/kilt) + +Generic templating interface for Crystal. + +## Goal + +Simplify developers' lives by abstracting template rendering for multiple template languages. + +## Supported out of the box + +| Language | File extensions | Required libraries | Maintainer | +| -------- | --------------- | ------------------ | ---------- | +| ECR | .ecr | none (part of the stdlib) | | +| Mustache | .mustache | [crustache](https://github.com/MakeNowJust/crustache) | [@MakeNowJust](https://github.com/MakeNowJust) | +| Slang | .slang | [slang](https://github.com/jeromegn/slang) | [@jeromegn](https://github.com/jeromegn) | +| Temel | .temel | [temel](https://github.com/f/temel) | [@f](https://github.com/f) | +| Crikey | .crikey | [crikey](https://github.com/domgetter/crikey) | [@domgetter](https://github.com/domgetter) | + +See also: +[Registering your own template engine](#registering-your-own-template-engine). + +## Installation + +Add this to your application's `shard.yml`: + +```yaml +dependencies: + kilt: + github: jeromegn/kilt + + # Any other template languages Crystal shard +``` + +## Usage + +- Kilt essentially adds two macros `Kilt.embed` and `Kilt.file`, the code is really simple. +- Add template language dependencies, as listed in the support table above. + +Both macros take a `filename` and a `io_name` (the latter defaults to `"__kilt_io__"`) + +### Example + +```crystal +require "kilt" + +# For slang, add: +require "kilt/slang" + +# With a Class + +class YourView + Kilt.file("path/to/template.ecr") # Adds a to_s method +end +puts YourView.new.to_s # => + + +# Embedded + +str = Kilt.render "path/to/template.slang" + +# or + +str = String.build do |__kilt_io__| + Kilt.embed "path/to/template.slang" +end + +puts str # => +``` + +## Registering your own template engine + +Use `Kilt.register_engine(extension, embed_command)` macro: + +```crystal +require "kilt" + +module MyEngine + macro embed(filename, io_name) + # .... + end +end + +Kilt.register_engine("myeng", MyEngine.embed) +``` + +This can be part of your own `my-engine` library: in this case it should depend +on `kilt` directly, or this could be a part of adapter library, like: +`kilt-my-engine`, which will depend on both `kilt` and `my-engine`. + +## Contributing + +Please contribute your own "adapter" if you create a template language for Crystal that's not yet supported here! + +1. Fork it ( https://github.com/jeromegn/kilt/fork ) +2. Create your feature branch (git checkout -b my-awesome-template-language) +3. Commit your changes (git commit -am 'Add my-awesome-template-language') +4. Push to the branch (git push origin my-awesome-template-language) +5. Create a new Pull Request + +## Contributors + +- [jeromegn](https://github.com/jeromegn) Jerome Gravel-Niquet - creator, maintainer +- [waterlink](https://github.com/waterlink) Oleksii Fedorov +- [MakeNowJust](https://github.com/MakeNowJust) TSUYUSATO Kitsune +- [f](https://github.com/f) Fatih Kadir Akın diff --git a/lib/kilt/shard.yml b/lib/kilt/shard.yml new file mode 100644 index 00000000..562cb39c --- /dev/null +++ b/lib/kilt/shard.yml @@ -0,0 +1,15 @@ +name: kilt +version: 0.4.0 + +authors: + - Jerome Gravel-Niquet + +license: MIT + +development_dependencies: + slang: + github: jeromegn/slang + crustache: + github: MakeNowJust/crustache + temel: + github: f/temel diff --git a/lib/kilt/spec/fixtures/test.ecr b/lib/kilt/spec/fixtures/test.ecr new file mode 100644 index 00000000..7ae32556 --- /dev/null +++ b/lib/kilt/spec/fixtures/test.ecr @@ -0,0 +1 @@ +<%= Process.pid %> \ No newline at end of file diff --git a/lib/kilt/spec/fixtures/test.mustache b/lib/kilt/spec/fixtures/test.mustache new file mode 100644 index 00000000..f31f09dc --- /dev/null +++ b/lib/kilt/spec/fixtures/test.mustache @@ -0,0 +1 @@ +{{pid}} \ No newline at end of file diff --git a/lib/kilt/spec/fixtures/test.raw b/lib/kilt/spec/fixtures/test.raw new file mode 100644 index 00000000..c57eff55 --- /dev/null +++ b/lib/kilt/spec/fixtures/test.raw @@ -0,0 +1 @@ +Hello World! \ No newline at end of file diff --git a/lib/kilt/spec/fixtures/test.slang b/lib/kilt/spec/fixtures/test.slang new file mode 100644 index 00000000..e8325c71 --- /dev/null +++ b/lib/kilt/spec/fixtures/test.slang @@ -0,0 +1 @@ +span = Process.pid \ No newline at end of file diff --git a/lib/kilt/spec/fixtures/test.temel b/lib/kilt/spec/fixtures/test.temel new file mode 100644 index 00000000..b214e1f4 --- /dev/null +++ b/lib/kilt/spec/fixtures/test.temel @@ -0,0 +1 @@ +span Process.pid diff --git a/lib/kilt/spec/kilt/crustache_spec.cr b/lib/kilt/spec/kilt/crustache_spec.cr new file mode 100644 index 00000000..f62b7d26 --- /dev/null +++ b/lib/kilt/spec/kilt/crustache_spec.cr @@ -0,0 +1,26 @@ +require "../spec_helper" +require "../../src/crustache" + +class MustacheView + def has_key?(name) + name == "pid" + end + + def [](name) + name == "pid" ? Process.pid : nil + end + + Kilt.file "spec/fixtures/test.mustache", "__kilt_io__", self +end + +describe "kilt/crustache" do + + it "renders crustache" do + Kilt.render("spec/fixtures/test.mustache", { "pid" => Process.pid }).should eq("#{Process.pid}") + end + + it "works with classes" do + MustacheView.new.to_s.should eq("#{Process.pid}") + end + +end diff --git a/lib/kilt/spec/kilt/slang_spec.cr b/lib/kilt/spec/kilt/slang_spec.cr new file mode 100644 index 00000000..4dbf6612 --- /dev/null +++ b/lib/kilt/spec/kilt/slang_spec.cr @@ -0,0 +1,18 @@ +require "../spec_helper" +require "../../src/slang" + +class SlangView + Kilt.file "spec/fixtures/test.slang" +end + +describe "kilt/slang" do + + it "renders slang" do + Kilt.render("spec/fixtures/test.slang").should eq("#{Process.pid}") + end + + it "works with classes" do + SlangView.new.to_s.should eq("#{Process.pid}") + end + +end diff --git a/lib/kilt/spec/kilt/temel_spec.cr b/lib/kilt/spec/kilt/temel_spec.cr new file mode 100644 index 00000000..aed4b48b --- /dev/null +++ b/lib/kilt/spec/kilt/temel_spec.cr @@ -0,0 +1,18 @@ +# require "../spec_helper" +# require "../../src/temel" + +# class TemelView +# Kilt.file "spec/fixtures/test.temel" +# end + +# describe "kilt/temel" do + +# it "renders temel" do +# Kilt.render("spec/fixtures/test.temel").should eq("#{Process.pid}") +# end + +# it "works with classes" do +# TemelView.new.to_s.should eq("#{Process.pid}") +# end + +# end diff --git a/lib/kilt/spec/kilt_spec.cr b/lib/kilt/spec/kilt_spec.cr new file mode 100644 index 00000000..dc3a3e50 --- /dev/null +++ b/lib/kilt/spec/kilt_spec.cr @@ -0,0 +1,28 @@ +require "./spec_helper" + +class View + Kilt.file "spec/fixtures/test.ecr" +end + +describe Kilt do + + it "renders ecr" do + Kilt.render("spec/fixtures/test.ecr").should eq("#{Process.pid}") + end + + it "works with classes" do + View.new.to_s.should eq("#{Process.pid}") + end + + it "raises with unsupported filetype" do + expect_raises(Kilt::Exception, "Unsupported template engine for extension: \"abc\"") { + Kilt.render("test.abc") + } + end + + it "renders registered engine" do + Kilt.register_engine "raw", Raw.embed + Kilt.render("spec/fixtures/test.raw").should eq("Hello World!") + end + +end diff --git a/lib/kilt/spec/spec_helper.cr b/lib/kilt/spec/spec_helper.cr new file mode 100644 index 00000000..d69d98e3 --- /dev/null +++ b/lib/kilt/spec/spec_helper.cr @@ -0,0 +1,3 @@ +require "spec" +require "../src/kilt" +require "./support/raw_engine" \ No newline at end of file diff --git a/lib/kilt/spec/support/raw_engine.cr b/lib/kilt/spec/support/raw_engine.cr new file mode 100644 index 00000000..2e0f9dc6 --- /dev/null +++ b/lib/kilt/spec/support/raw_engine.cr @@ -0,0 +1,5 @@ +module Raw + macro embed(filename, io) + {{ io.id }} << {{`cat #{filename}`.stringify}} + end +end \ No newline at end of file diff --git a/lib/kilt/src/crikey.cr b/lib/kilt/src/crikey.cr new file mode 100644 index 00000000..f20b9e21 --- /dev/null +++ b/lib/kilt/src/crikey.cr @@ -0,0 +1,4 @@ +require "./kilt" +require "crikey" + +Kilt.register_engine "crikey", Crikey.embed diff --git a/lib/kilt/src/crustache.cr b/lib/kilt/src/crustache.cr new file mode 100644 index 00000000..e766e6b4 --- /dev/null +++ b/lib/kilt/src/crustache.cr @@ -0,0 +1,4 @@ +require "./kilt" +require "crustache" + +Kilt.register_engine "mustache", Mustache.embed \ No newline at end of file diff --git a/lib/kilt/src/ecr.cr b/lib/kilt/src/ecr.cr new file mode 100644 index 00000000..2360c0d4 --- /dev/null +++ b/lib/kilt/src/ecr.cr @@ -0,0 +1,3 @@ +require "ecr/macros" + +Kilt.register_engine("ecr", ECR.embed) diff --git a/lib/kilt/src/kilt.cr b/lib/kilt/src/kilt.cr new file mode 100644 index 00000000..6dfbb0e8 --- /dev/null +++ b/lib/kilt/src/kilt.cr @@ -0,0 +1,35 @@ +require "./kilt/version" +require "./kilt/exception" + +module Kilt + # macro only constant + ENGINES = {} of String => Int32 + + macro register_engine(ext, embed_macro) + {% Kilt::ENGINES[ext] = embed_macro.id %} + end + + macro embed(filename, io_name = "__kilt_io__", *args) + {% ext = filename.split(".").last %} + + {% if Kilt::ENGINES[ext] %} + {{Kilt::ENGINES[ext]}}({{filename}}, {{io_name}}, {{*args}}) + {% else %} + raise Kilt::Exception.new("Unsupported template engine for extension: \"" + {{ext}} + "\"") + {% end %} + end + + macro render(filename, *args) + String.build do |__kilt_io__| + Kilt.embed({{filename}}, "__kilt_io__", {{*args}}) + end + end + + macro file(filename, io_name = "__kilt_io__", *args) + def to_s({{io_name.id}}) + Kilt.embed({{filename}}, {{io_name}}, {{*args}}) + end + end +end + +require "./ecr" diff --git a/lib/kilt/src/kilt/exception.cr b/lib/kilt/src/kilt/exception.cr new file mode 100644 index 00000000..2e43dfd3 --- /dev/null +++ b/lib/kilt/src/kilt/exception.cr @@ -0,0 +1,5 @@ +module Kilt + class Exception < ::Exception + # Nothing special + end +end \ No newline at end of file diff --git a/lib/kilt/src/kilt/helpers/temel_embedder.cr b/lib/kilt/src/kilt/helpers/temel_embedder.cr new file mode 100644 index 00000000..38967c25 --- /dev/null +++ b/lib/kilt/src/kilt/helpers/temel_embedder.cr @@ -0,0 +1,3 @@ +require "temel" + +puts File.read(ARGV[0]).to_s STDOUT diff --git a/lib/kilt/src/kilt/version.cr b/lib/kilt/src/kilt/version.cr new file mode 100644 index 00000000..2153f5ad --- /dev/null +++ b/lib/kilt/src/kilt/version.cr @@ -0,0 +1,3 @@ +module Kilt + VERSION = "0.4.0" +end diff --git a/lib/kilt/src/slang.cr b/lib/kilt/src/slang.cr new file mode 100644 index 00000000..249af603 --- /dev/null +++ b/lib/kilt/src/slang.cr @@ -0,0 +1,4 @@ +require "./kilt" +require "slang" + +Kilt.register_engine "slang", Slang.embed diff --git a/lib/kilt/src/temel.cr b/lib/kilt/src/temel.cr new file mode 100644 index 00000000..13a6423e --- /dev/null +++ b/lib/kilt/src/temel.cr @@ -0,0 +1,9 @@ +require "./kilt" +require "temel" + +macro embed_temel(filename, __kilt_io__) + __kilt_io__ << {{ run("./kilt/helpers/temel_embedder.cr", filename) }} + __kilt_io__ +end + +Kilt.register_engine "temel", embed_temel diff --git a/lib/pg/.circleci/config.yml b/lib/pg/.circleci/config.yml new file mode 100644 index 00000000..c05124d3 --- /dev/null +++ b/lib/pg/.circleci/config.yml @@ -0,0 +1,50 @@ +version: 2 + +testdefault: &testdefault + docker: + - image: crystallang/crystal:latest + steps: + - checkout + - run: + name: postgres setup + command: | + bash .circleci/install.sh + bash .circleci/setup.sh + touch spec/.run_auth_specs + - run: crystal --version + - run: crystal tool format --check + - run: shards install + - run: + name: set DATABASE_URL + command: echo 'export DATABASE_URL=postgres://postgres@localhost/postgres' >> $BASH_ENV + - run: crystal spec --error-on-warnings + +noscram: &noscram + environment: + NOSCRAM: true + +jobs: + postgresql-11: + <<: *testdefault + postgresql-10: + <<: *testdefault + postgresql-9.6: + <<: *testdefault + <<: *noscram + postgresql-9.5: + <<: *testdefault + <<: *noscram + postgresql-9.4: + <<: *testdefault + <<: *noscram + +workflows: + version: 2 + ci: + jobs: + - postgresql-11 + - postgresql-10 + - postgresql-9.6 + - postgresql-9.5 + - postgresql-9.4 + diff --git a/lib/pg/.circleci/install.sh b/lib/pg/.circleci/install.sh new file mode 100644 index 00000000..ea93f870 --- /dev/null +++ b/lib/pg/.circleci/install.sh @@ -0,0 +1,9 @@ +set -e + +apt-get update +apt-get install curl -y +curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add - +echo "deb http://apt.postgresql.org/pub/repos/apt/ xenial-pgdg main" > /etc/apt/sources.list.d/pgdg.list +apt-get update +apt-get install $CIRCLE_JOB -y + diff --git a/lib/pg/.circleci/pg_hba.conf b/lib/pg/.circleci/pg_hba.conf new file mode 100644 index 00000000..ef28d4ca --- /dev/null +++ b/lib/pg/.circleci/pg_hba.conf @@ -0,0 +1,5 @@ +# TYPE DATABASE USER ADDRESS METHOD +local all postgres trust +host all postgres 127.0.0.1/32 trust +host all crystal_md5 127.0.0.1/32 md5 +hostssl all crystal_ssl 127.0.0.1/32 cert clientcert=1 diff --git a/lib/pg/.circleci/setup.sh b/lib/pg/.circleci/setup.sh new file mode 100644 index 00000000..1bf1ecf5 --- /dev/null +++ b/lib/pg/.circleci/setup.sh @@ -0,0 +1,41 @@ +set -e + +VER=$(crystal eval 'puts ENV["CIRCLE_JOB"].split("-").last') +CONF=/etc/postgresql/$VER/main +echo "VER=#{$VER}" +echo "CONF=#{$CONF}" + +cp .circleci/pg_hba.conf $CONF +if [ -v NOSCRAM ]; then + echo "not adding scram to pg_hba" +else + echo "host all crystal_scram 127.0.0.1/32 scram-sha-256" >> $CONF/pg_hba.conf +fi + +mkdir .cert +chmod 700 .cert +cd .cert + +openssl req -new -nodes -text -out ca.csr -keyout ca-key.pem -subj "/CN=certificate-authority" +openssl x509 -req -in ca.csr -text -extfile /etc/ssl/openssl.cnf -extensions v3_ca -signkey ca-key.pem -out ca-cert.pem +openssl req -new -nodes -text -out server.csr -keyout server-key.pem -subj "/CN=pg-server" +openssl x509 -req -in server.csr -text -CA ca-cert.pem -CAkey ca-key.pem -CAcreateserial -out server-cert.pem +openssl req -new -nodes -text -out client.csr -keyout client-key.pem -subj "/CN=crystal_ssl" +openssl x509 -req -in client.csr -text -CA ca-cert.pem -CAkey ca-key.pem -CAcreateserial -out client-cert.pem +chmod 600 * + +cp ca-cert.pem root.crt +mv client-cert.pem crystal_ssl.crt +mv client-key.pem crystal_ssl.key +openssl verify -CAfile root.crt crystal_ssl.crt + +cp server-cert.pem $CONF +cp server-key.pem $CONF/ +cp ca-cert.pem $CONF/ +chown postgres $CONF/*.pem +echo "ssl = on" >> $CONF/postgresql.conf +echo "ssl_cert_file = '$CONF/server-cert.pem'" >> $CONF/postgresql.conf +echo "ssl_key_file = '$CONF/server-key.pem'" >> $CONF/postgresql.conf +echo "ssl_ca_file = '$CONF/ca-cert.pem'" >> $CONF/postgresql.conf + +pg_ctlcluster $VER main restart diff --git a/lib/pg/.gitignore b/lib/pg/.gitignore new file mode 100644 index 00000000..4c55ba9e --- /dev/null +++ b/lib/pg/.gitignore @@ -0,0 +1,8 @@ +.crystal +.shards +shard.lock +lib +tmp +test*.cr +spec/.run_auth_specs +.cert diff --git a/lib/pg/CHANGELOG b/lib/pg/CHANGELOG new file mode 100644 index 00000000..e6e69091 --- /dev/null +++ b/lib/pg/CHANGELOG @@ -0,0 +1,188 @@ +v?.?.? upcoming +===================== + +v0.18.1 2018-08-09 +===================== +* bugfix: Fix Time encoding for non-zero-offset times (thanks @straight-shoota) + +v0.18.0 2019-08-04 +===================== +* Add SCRAM-SHA-256 support + * NOTE: SASLPrep is missing as of this realease, so not all passwords work +* Fix reading large number of UUIDs (thanks @asterite) +* Correctly encode timestamp values to consider microseconds (thanks @asterite) +* Update crystal db support to 0.6.0 (thanks @bcardiff) + +v0.17.0 2019-07-19 +===================== +* Add `Enumerable` channels overload for `PG.connect_listen` (thanks @vladfaust) +* Fixes Time.new deprecation warning (thanks @bcardiff) +* Ensure PQ::Connection#do_close does not raise (thanks @bcardiff) +* Fix IOError on closed connection (thanks @omarroth) +* Array fixes (thanks @asterite) + * properly handle exceptions during decoding array + * add Numeric#inspect + * prevent decoding array of numeric as floats + * internal: map each type to decoder + +v0.16.1 2019-04-15 +===================== +* Support Crystal v0.28.0 (thanks @bcardiff) +* Fix support for reading a NoticeResponse at row start (thanks @straight-shoota) + +v0.16.0 2019-04-02 +===================== +* Support connection negotiation without BackendKeyData frame (thanks @rx14) +* Make connection client encoding check case insesitive (thanks @yumoose) +* Support client cert auth (thanks @sanderhahn) +* Fix reading null when expecting array (thanks @straight-shoota) +* Fix encoding string arrays with special characters (thanks @straight-shoota) +* Fix version parser (thanks @straight-shoota) + +v0.15.0 2018-06-15 +===================== +* Support Crystal v0.25 (thanks @greenbigfrog) +* fix PG::Numeric#to_s on numbers with >= weight than ndigits, eg 800000 (#133) +* fix PG::Numeric#to_s on numbers where the `digits` expected to be zero padded + and other cases (#134) +* Set cause when raising DB::ConnectionRefused (thanks @rx14) + +v0.14.1 2017-12-26 +===================== +* Update crystal db support to 0.5.0 (thanks @bcardiff) + +v0.14.0 2017-12-26 +===================== +* Support Crystal v0.24, breaks support for older versions +* Support Postgres v10.0+ new two digit version scheme in #version + +v0.13.4 2017-10-16 +===================== +* fix CI (thanks @waghanza) +* bugfix: make sure to read all nulls (thanks @ZeWebDev) +* bugfix: no longer hangs on unhandeled exceptions (thanks @bigtunacan) + +v0.13.3 2017-03-21 +===================== +* Increased precision when encoding times +* Use DB.connect in ListenConnection to avoid creating a connection pool (thanks @bcardiff) +* Updates to use db 0.4.0. (thanks again @bcardiff) + +v0.13.2 2017-02-21 +===================== +* update to crystal 0.21.0 (thanks @felipeelias) + +v0.13.1 2016-12-25 +===================== +* update to crystal 0.20.3 (thanks @bcardiff) + +v0.13.0 2016-12-21 +===================== +* Update to support crystal-db's connection pooling (thanks @bcardiff) + +v0.12.0 2016-12-10 +===================== +* Uses crystal-db api (thanks @asterite) + +v0.11.0 2016-09-04 +===================== +* Adds dedicated LISTEN connection + +v0.10.0 2016-09-03 +===================== +* Adds support for array types + +v0.9.1 2016-08-23 +===================== +* Wrap query execution in a mutext to prevent protocol desynchronization + +v0.9.0 2016-07-26 +===================== +* remove support for plaintext auth to prevent downgrade mitm attacks +* fix for multibyte characters in query strings + +v0.8.0 2016-06-17 +===================== +* (breaking) geo points now all have own types (thanks @asterite) +* support crystal 0.18 + +v0.7.1 2016-05-15 +===================== +* (breaking) unknown oids decoded as byte slices instead of string +* fix shard.yml +* Adds optional BigRational extension to PG::Numeric +* bugfixes in protocol +* adds geo types + +v0.7.0 2016-05-05 +===================== +new features + * 100% crystal, using crystal's native async io (no longer using libpq) + * each Connection#exec form can take a block, and stream the rows as they + come in. This does not store any rows in memory and is very fast. + * adds #on_notification for listen/notify support + +incompatable changes: + * Result.each now yields the entire row and an array of fields + * The notice callback now yields an entire + * Error classes have changed, some removed + * "db_name" changed to "dbname" in hash #initialize to match postgres + * on_notice callback gets an object instead of just a string + +v0.6.1 2016-05-15 +==================== +* fix shard.yml + +v0.6.0 2016-05-04 +==================== +* Adds Adds on_notice callback (thanks @radiospiel) +* Adds PG::Numeric for numeric/decimal support +* Note: this will be the last release that links LibPQ + +v0.5.0 2015-12-21 +==================== +* Adds Result#each which allows to map a PG result to a struct or class, + avoiding temporary memory structures (thanks @ysbaddaden) +* Connection#exec is now async (thanks @ysbaddaden) + +v0.4.3 2015-10-21 +==================== +* Support for byta (thanks @jhass) +* `Connection`s can be made with a hash of params (thanks @tebakane) +* Support Crystal 0.9.0 and bytea encoding fix (thanks @technorama) + +v0.4.2 2015-09-29 +==================== +* fix UUIDs to have appropriate dashes + +v0.4.1 2015-09-29 +==================== +* Fix UUID type + +v0.4.0 2015-09-19 +==================== +* Fix and require Crystal 0.8.0 +* Allow #to_hash with typed querying interface (thanks @werner) + +v0.3.2 2015-08-28 +==================== +* Connection#exec_all (thanks @solisoft) +* BUGFIX: bigint and smallint were broken since v0.3.0 + +v0.3.1 2015-08-25 +==================== +* BUGFIX: previous release would give wrong reults under --release +* 3x fater again on time parsing, total ~10x faster than text format + +v0.3.0 2015-08-24 +==================== +* switch to much faster binary result format (time parsing 3.4x faster than text) +* decoders are now pluggable with PG::Decoder.register_decoder + +v0.2.0 2015-07-25 +==================== +* Add Connection#escape_literal and Connection#escape_identifier + +v0.1.0 2015-06-14 +==================== +* first named version using crystal shards diff --git a/lib/pg/CONTRIBUTING b/lib/pg/CONTRIBUTING new file mode 100644 index 00000000..47142cd4 --- /dev/null +++ b/lib/pg/CONTRIBUTING @@ -0,0 +1,3 @@ +Make sure the change is properly tested and the build passes. +Submit a PR. +Thanks! diff --git a/lib/pg/CONTRIBUTORS b/lib/pg/CONTRIBUTORS new file mode 100644 index 00000000..32d3eeff --- /dev/null +++ b/lib/pg/CONTRIBUTORS @@ -0,0 +1,23 @@ +Will Leinweber +Pedro Belo +Werner Echezuría +Maciek Sakrejda +Olivier BONNAURE +Jonne Haß +technorama +tebakane +Julien Portalier +radiospiel +Ary Borenszweig +Brian J. Cardiff +Felipe Philipp +Marwan Rabbâa +Joiey Seeley +ZeWebDev +Chris Hobbs +greenbigfrog +yumoose +Sander Hahn +Johannes Müller +Vlad Faust +Omar Roth diff --git a/lib/pg/LICENSE b/lib/pg/LICENSE new file mode 100644 index 00000000..8baf8820 --- /dev/null +++ b/lib/pg/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2015, Will Leinweber +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of crystal-pg nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/lib/pg/README.md b/lib/pg/README.md new file mode 100644 index 00000000..ec8caa7f --- /dev/null +++ b/lib/pg/README.md @@ -0,0 +1,83 @@ +# crystal-pg +A native, non-blocking Postgres driver for Crystal + +[![Build Status](https://circleci.com/gh/will/crystal-pg/tree/master.svg?style=svg)](https://circleci.com/gh/will/crystal-pg/tree/master) + + +## usage + +This driver now uses the `crystal-db` project. Documentation on connecting, +querying, etc, can be found at: + +* https://crystal-lang.org/docs/database/ +* https://crystal-lang.org/docs/database/connection_pool.html + +### shards + +Add this to your `shard.yml` and run `shards install` + +``` yml +dependencies: + pg: + github: will/crystal-pg +``` + +### Listen/Notify + +There are two ways to listen for notifications. For docs on `NOTIFY`, please +read . + +1. Any connection can be given a callback to run on notifications. However they + are only received when other traffic is going on. +2. A special listen-only connection can be established for instant notification + processing with `PG.connect_listen`. + +``` crystal +# see full example in examples/listen_notify.cr +PG.connect_listen("postgres:///", "a", "b") do |n| # connect and listen on "a" and "b" + puts " got: #{n.payload} on #{n.channel}" # print notifications as they come in +end +``` + +### Arrays + +Crystal-pg supports several popular array types. If you only need a 1 +dimensional array, you can cast down to the appropriate Crystal type: + +``` crystal +PG_DB.query_one("select ARRAY[1, null, 3]", &.read(Array(Int32?)) +# => [1, nil, 3] + +PG_DB.query_one("select '{hello, world}'::text[]", &.read(Array(String)) +# => ["hello", "world"] +``` + +## Requirements + +Crystal-pg is [regularly tested on](https://circleci.com/gh/will/crystal-pg) +the Postgres versions the [Postgres project itself supports](https://www.postgresql.org/support/versioning/). +Since it uses protocol version 3, older versions probably also work but are not guaranteed. + +## Supported Datatypes + +- text +- boolean +- int8, int4, int2 +- float4, float8 +- timestamptz, date, timestamp (but no one should use ts when tstz exists!) +- json and jsonb +- uuid +- bytea +- numeric/decimal (1) +- varchar +- regtype +- geo types: point, box, path, lseg, polygon, circle, line +- array types: int8, int4, int2, float8, float4, bool, text, numeric, timestamptz, date, timestamp + +1: A note on numeric: In Postgres this type has arbitrary precision. In this + driver, it is represented as a `PG::Numeric` which retains all precision, but + if you need to do any math on it, you will probably need to cast it to a + float first. If you need true arbitrary precision, you can optionally + require `pg_ext/big_rational` which adds `#to_big_r`, but requires that you + have LibGMP installed. + diff --git a/lib/pg/examples/command b/lib/pg/examples/command new file mode 100755 index 00000000..56674159 --- /dev/null +++ b/lib/pg/examples/command @@ -0,0 +1,6 @@ +#!/usr/bin/env crystal +require "../src/pg" + +DB.open(ENV["DATABASE_URL"]) do |db| + p db.query_one(ARGV[0], &.read) +end diff --git a/lib/pg/examples/listen_notify.cr b/lib/pg/examples/listen_notify.cr new file mode 100644 index 00000000..97998d3b --- /dev/null +++ b/lib/pg/examples/listen_notify.cr @@ -0,0 +1,39 @@ +require "../src/pg" + +PG.connect_listen("postgres:///", "a", "b") do |n| # connect and listen on "a" and "b" + puts " got: #{n.payload} on #{n.channel}" # print notifications as they come in +end + +PG_DB = DB.open("postgres:///") # make a normal connection +spawn do # spawn a coroutine + 10.times do |i| # + chan = rand > 0.5 ? "a" : "b" # pick a channel + puts "sending: #{i}" # prints always before "got:" + PG_DB.exec("SELECT pg_notify($1, $2)", [chan, i]) # send notification + puts " sent: #{i}" # may print before or after "got:" + sleep 1 + end +end + +sleep 6 # # wait a bit before exiting + +# Example output. Ordering and channels will vary. +# +# sending: 0 +# sent: 0 +# got: 0 on a +# sending: 1 +# got: 1 on a +# sent: 1 +# sending: 2 +# sent: 2 +# got: 2 on a +# sending: 3 +# sent: 3 +# got: 3 on b +# sending: 4 +# sent: 4 +# got: 4 on b +# sending: 5 +# got: 5 on a +# sent: 5 diff --git a/lib/pg/examples/shoddy_psql.cr b/lib/pg/examples/shoddy_psql.cr new file mode 100755 index 00000000..d7602080 --- /dev/null +++ b/lib/pg/examples/shoddy_psql.cr @@ -0,0 +1,72 @@ +#!/usr/bin/env crystal +require "readline" +require "../src/pg" + +url = ARGV[0]? || "postgres:///" +db = DB.open(url) + +loop do + query = Readline.readline("# ", true) || "" + has_results = false + begin + db.query(query) do |rs| + has_results = rs.column_count > 0 + if has_results + # Gather rows, including a first row for the column names + rows = [] of Array(typeof(rs.read)) + + # The first row: column names + rows << rs.column_count.times.map { |i| rs.column_name(i).as(typeof(rs.read)) }.to_a + + # The result rows + rs.each do + rows << rs.column_count.times.map { rs.read }.to_a + end + + # Compute maximum sizes for each column for a nicer output + sizes = [] of Int32 + rs.column_count.times do |i| + # Add 2 for padding + sizes << rows.max_of(&.[i].to_s.size.+(2)) + end + + # Print rows + rows.each_with_index do |row, row_index| + row.each_with_index do |value, col_index| + print " |" if col_index > 0 + print " " + col_size = sizes[col_index] - 2 + case value + when Int, Float, PG::Numeric + print value.to_s.rjust(col_size) + else + print value.to_s.ljust(col_size) + end + end + + # Write the separator ("---+---+---") after the first row + if row_index == 0 + puts + row.each_index do |col_index| + print "+" if col_index > 0 + print "-" * sizes[col_index] + end + end + + puts + end + + # Print numbers of rows + count = rows.size - 1 + if count == 1 + puts "(1 row)" + else + puts "(#{count} rows)" + end + end + end + rescue e + puts "ERROR: #{e.message}" + end + puts if has_results +end diff --git a/lib/pg/shard.yml b/lib/pg/shard.yml new file mode 100644 index 00000000..83520ae7 --- /dev/null +++ b/lib/pg/shard.yml @@ -0,0 +1,7 @@ +name: pg +version: 0.18.1 + +dependencies: + db: + github: crystal-lang/crystal-db + version: ~> 0.6.0 diff --git a/lib/pg/spec/pg/connection_spec.cr b/lib/pg/spec/pg/connection_spec.cr new file mode 100644 index 00000000..3dedb40d --- /dev/null +++ b/lib/pg/spec/pg/connection_spec.cr @@ -0,0 +1,106 @@ +require "../spec_helper" + +describe PG::Connection, "#initialize" do + it "raises on bad connections" do + expect_raises(DB::ConnectionRefused) { + DB.open("postgres://localhost:5433") + } + end +end + +describe PG::Connection, "#on_notice" do + it "sends notices to on_notice" do + last_notice = nil + PG_DB.using_connection do |conn| + conn.on_notice do |notice| + last_notice = notice + end + end + + PG_DB.using_connection do |conn| + conn.exec_all <<-SQL + SET client_min_messages TO notice; + DO language plpgsql $$ + BEGIN + RAISE NOTICE 'hello, world!'; + END + $$; + SQL + end + + last_notice.should_not eq(nil) + last_notice.to_s.should eq("NOTICE: hello, world!\n") + end +end + +describe PG::Connection, "#on_notification" do + it "does listen/notify within same connection" do + last_note = nil + with_db do |db| + db.using_connection do |conn| + conn.on_notification { |note| last_note = note } + + conn.exec("listen somechannel") + conn.exec("notify somechannel, 'do a thing'") + end + end + + last_note.not_nil!.channel.should eq("somechannel") + last_note.not_nil!.payload.should eq("do a thing") + end +end + +describe PG, "#listen" do + it "opens a special listen only connection" do + got = false + ch = Channel(Nil).new + conn = PG.connect_listen(DB_URL, "foo", "bar") do |n| + got = true + ch.send(nil) + end + + begin + got.should eq(false) + + PG_DB.exec("notify wrong, 'hello'") + got.should eq(false) + + PG_DB.exec("notify foo, 'hello'") + ch.receive + got.should eq(true) + got = false + + PG_DB.exec("notify bar, 'hello'") + ch.receive + got.should eq(true) + ensure + conn.close + end + end +end + +describe PG, "#read_next_row_start" do + it "handles reading a notice" do + with_connection do |db| + db.exec "SET client_min_messages TO notice" + db.exec <<-SQL + CREATE OR REPLACE FUNCTION foo() RETURNS integer AS $$ + BEGIN + RAISE NOTICE 'foo'; + RAISE NOTICE 'bar'; + RETURN 42; + END; + $$ LANGUAGE plpgsql; + SQL + + received_notices = [] of String + db.on_notice do |notice| + received_notices << notice.message + end + db.scalar("SELECT foo()").should eq 42 + received_notices.should eq ["foo", "bar"] + + db.exec("DROP FUNCTION foo()") + end + end +end diff --git a/lib/pg/spec/pg/decoder_spec.cr b/lib/pg/spec/pg/decoder_spec.cr new file mode 100644 index 00000000..a5ec4fd0 --- /dev/null +++ b/lib/pg/spec/pg/decoder_spec.cr @@ -0,0 +1,83 @@ +require "../spec_helper" + +describe PG::Decoders do + # name, sql, result + test_decode "undefined ", "'what' ", "what" + test_decode "text ", "'what'::text ", "what" + test_decode "varchar ", "'wh'::varchar", "wh" + test_decode "empty strings", "'' ", "" + test_decode "null as nil ", "null ", nil + test_decode "boolean false", "false ", false + test_decode "boolean true ", "true ", true + test_decode "int2 smallint", "1::int2 ", 1 + test_decode "int4 int ", "1::int4 ", 1 + test_decode "int8 bigint ", "1::int8 ", 1 + test_decode "float ", "-0.123::float", -0.123 + test_decode "regtype ", "pg_typeof(3) ", 23 + + test_decode "double prec.", "'35.03554004971999'::float8", 35.03554004971999 + test_decode "flot prec.", "'0.10000122'::float4", 0.10000122_f32 + + test_decode "bytea", "E'\\\\001\\\\134\\\\176'::bytea", + Slice(UInt8).new(UInt8[0o001, 0o134, 0o176].to_unsafe, 3) + test_decode "bytea", "E'\\\\005\\\\000\\\\377\\\\200'::bytea", + Slice(UInt8).new(UInt8[5, 0, 255, 128].to_unsafe, 4) + test_decode "bytea empty", "E''::bytea", + Slice(UInt8).new(UInt8[].to_unsafe, 0) + + test_decode "uuid", "'7d61d548124c4b38bc05cfbb88cfd1d1'::uuid", + "7d61d548-124c-4b38-bc05-cfbb88cfd1d1" + test_decode "uuid", "'7d61d548-124c-4b38-bc05-cfbb88cfd1d1'::uuid", + "7d61d548-124c-4b38-bc05-cfbb88cfd1d1" + + if Helper.db_version_gte(9, 2) + test_decode "json", %('[1,"a",true]'::json), JSON.parse(%([1,"a",true])) + test_decode "json", %('{"a":1}'::json), JSON.parse(%({"a":1})) + end + if Helper.db_version_gte(9, 4) + test_decode "jsonb", "'[1,2,3]'::jsonb", JSON.parse("[1,2,3]") + end + + test_decode "timestamptz", "'2015-02-03 16:15:13-01'::timestamptz", + Time.utc(2015, 2, 3, 17, 15, 13) + + test_decode "timestamptz", "'2015-02-03 16:15:14.23-01'::timestamptz", + Time.utc(2015, 2, 3, 17, 15, 14, nanosecond: 230_000_000) + + test_decode "timestamp", "'2015-02-03 16:15:15'::timestamp", + Time.utc(2015, 2, 3, 16, 15, 15) + + test_decode "date", "'2015-02-03'::date", + Time.utc(2015, 2, 3, 0, 0, 0) + + it "numeric" do + x = ->(q : String) do + PG_DB.query_one "select '#{q}'::numeric", &.read(PG::Numeric) + end + x.call("1.3").to_f.should eq(1.3) + x.call("nan").nan?.should be_true + end + + it "decodes many uuids (#148)" do + uuid = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + ids = PG_DB.query_all("select '#{uuid}'::uuid from generate_series(1,1000)", as: String) + ids.uniq.should eq([uuid]) + end + + test_decode "xml", "'false'::xml", "false" + test_decode "char", %('c'::"char"), 'c' + test_decode "bpchar", %('c'::char), "c" + test_decode "bpchar", %('c'::char(5)), "c " + test_decode "name", %('hi'::name), "hi" + test_decode "oid", %(2147483648::oid), 2147483648_u32 + test_decode "point", "'(1.2,3.4)'::point", PG::Geo::Point.new(1.2, 3.4) + if Helper.db_version_gte(9, 4) + test_decode "line ", "'(1,2,3,4)'::line ", PG::Geo::Line.new(1.0, -1.0, 1.0) + test_decode "line ", "'1,2,3'::circle ", PG::Geo::Circle.new(1.0, 2.0, 3.0) + end + test_decode "lseg ", "'(1,2,3,4)'::lseg ", PG::Geo::LineSegment.new(1.0, 2.0, 3.0, 4.0) + test_decode "box ", "'(1,2,3,4)'::box ", PG::Geo::Box.new(1.0, 2.0, 3.0, 4.0) + test_decode "path ", "'(1,2,3,4)'::path ", PG::Geo::Path.new([PG::Geo::Point.new(1.0, 2.0), PG::Geo::Point.new(3.0, 4.0)], closed: true) + test_decode "path ", "'[1,2,3,4,5,6]'::path", PG::Geo::Path.new([PG::Geo::Point.new(1.0, 2.0), PG::Geo::Point.new(3.0, 4.0), PG::Geo::Point.new(5.0, 6.0)], closed: false) + test_decode "polygon", "'1,2,3,4,5,6'::polygon", PG::Geo::Polygon.new([PG::Geo::Point.new(1.0, 2.0), PG::Geo::Point.new(3.0, 4.0), PG::Geo::Point.new(5.0, 6.0)]) +end diff --git a/lib/pg/spec/pg/decoders/array_decoder_spec.cr b/lib/pg/spec/pg/decoders/array_decoder_spec.cr new file mode 100644 index 00000000..d7dc5940 --- /dev/null +++ b/lib/pg/spec/pg/decoders/array_decoder_spec.cr @@ -0,0 +1,106 @@ +require "../../spec_helper" + +describe PG::Decoders do + test_decode "array", "'{}'::integer[]", [] of Int32 + test_decode "array", "ARRAY[9]", [9] + test_decode "array", "ARRAY[8,9]", [8, 9] + test_decode "array", "'{{9,8},{7,6},{5,4}}'::integer[]", + [[9, 8], [7, 6], [5, 4]] + test_decode "array", "'{ {9,8,7}, {6,5,4} }'::integer[] ", + [[9, 8, 7], [6, 5, 4]] + test_decode "array", "'{{{1,2},{3,4}},{{9,8},{7,6}}}'::integer[]", + [[[1, 2], [3, 4]], [[9, 8], [7, 6]]] + test_decode "array", "ARRAY[1, null, 2] ", [1, nil, 2] + test_decode "array", "('[3:5]={1,2,3}'::integer[])", [nil, nil, 1, 2, 3] + + it "allows special-case casting on simple arrays" do + value = PG_DB.query_one("select '{}'::integer[]", &.read(Array(Int32))) + typeof(value).should eq(Array(Int32)) + value.empty?.should be_true + + value = PG_DB.query_one("select '{1,2,3}'::integer[]", &.read(Array(Int32))) + typeof(value).should eq(Array(Int32)) + value.should eq([1, 2, 3]) + + value = PG_DB.query_one("select '{1,2,3,null}'::integer[]", &.read(Array(Int32?))) + typeof(value).should eq(Array(Int32?)) + value.should eq([1, 2, 3, nil]) + + value = PG_DB.query_one("select '{{1,2,3},{4,5,6}}'::integer[]", &.read(Array(Array(Int32)))) + typeof(value).should eq(Array(Array(Int32))) + value.should eq([[1, 2, 3], [4, 5, 6]]) + end + + it "reads array as nilable" do + value = PG_DB.query_one("select '{1,2,3}'::integer[]", &.read(Array(Int32)?)) + typeof(value).should eq(Array(Int32)?) + value.should eq([1, 2, 3]) + + value = PG_DB.query_one("select null", &.read(Array(Int32)?)) + typeof(value).should eq(Array(Int32)?) + value.should be_nil + end + + it "reads arrray of numeric" do + value = PG_DB.query_one("select '{1,2,3}'::numeric[]", &.read(Array(PG::Numeric))) + typeof(value).should eq(Array(PG::Numeric)) + value.map(&.to_f).should eq([1, 2, 3]) + end + + it "reads arrray of nilable numeric" do + value = PG_DB.query_one("select '{1,null,3}'::numeric[]", &.read(Array(PG::Numeric?))) + typeof(value).should eq(Array(PG::Numeric?)) + value.map(&.try &.to_f).should eq([1, nil, 3]) + end + + it "raises when reading null in non-null array" do + expect_raises(PG::RuntimeError) do + PG_DB.query_one("select '{1,2,3,null}'::integer[]", &.read(Array(Int32))) + end + end + + it "reads array of time" do + values = PG_DB.query_one("select array[to_date('20170103', 'YYYYMMDD')::timestamp]", &.read(Array(Time))) + typeof(values).should eq(Array(Time)) + + values.size.should eq(1) + values[0].should eq(Time.utc(2017, 1, 3)) + end + + it "reads array of date" do + values = PG_DB.query_one("select array[to_date('20170103', 'YYYYMMDD')]", &.read(Array(Time))) + typeof(values).should eq(Array(Time)) + + values.size.should eq(1) + values[0].should eq(Time.utc(2017, 1, 3)) + end + + it "raises when reading incorrect array type" do + expect_raises(PG::RuntimeError) do + PG_DB.query_one("select '{1,2,3}'::numeric[]", &.read(Array(Float64))) + end + end + + it "errors on negative lower bounds" do + expect_raises(PG::RuntimeError) do + PG_DB.query_one("select '[-2:-0]={1,2,3}'::integer[]", &.read) + end + end + + test_decode "bool array", "$${t,f,t}$$::bool[]", [true, false, true] + test_decode "char array", "$${a, b}$$::\"char\"[]", ['a', 'b'] + test_decode "int2 array", "$${1,2}$$::int2[]", [1, 2] + test_decode "text array", "$${hello, world}$$::text[]", ["hello", "world"] + test_decode "int8 array", "$${1,2}$$::int8[]", [1, 2] + test_decode "float4 array", "$${1.1,2.2}$$::float4[]", [1.1_f32, 2.2_f32] + test_decode "float8 array", "$${1.1,2.2}$$::float8[]", [1.1_f64, 2.2_f64] + test_decode "date array", "array[to_date('20170103', 'YYYYMMDD')]", [Time.utc(2017, 1, 3)] + test_decode "numeric array", "array[1::numeric]", [PG::Numeric.new(ndigits: 1, weight: 0, sign: PG::Numeric::Sign::Pos.value, dscale: 0, digits: [1] of Int16)] + test_decode "time array", "array[to_date('20170103', 'YYYYMMDD')::timestamp]", [Time.utc(2017, 1, 3)] + + it "errors when expecting array returns null" do + expect_raises(PG::RuntimeError, "unexpected NULL, expecting to read Array(String)") do + PG_DB.query_one("SELECT NULL", as: Array(String)) + end + end +end diff --git a/lib/pg/spec/pg/driver_spec.cr b/lib/pg/spec/pg/driver_spec.cr new file mode 100644 index 00000000..07509f98 --- /dev/null +++ b/lib/pg/spec/pg/driver_spec.cr @@ -0,0 +1,169 @@ +require "../spec_helper" + +def assert_single_read(rs, value_type, value) + rs.move_next.should be_true + rs.read(value_type).should eq(value) + rs.move_next.should be_false +end + +class NotSupportedType +end + +struct StructWithMapping + DB.mapping(a: Int32, b: Int32) +end + +describe PG::Driver do + it "should register postgres name" do + DB.driver_class("postgres").should eq(PG::Driver) + end + + it "exectes and selects value" do + PG_DB.query "select 123::int4" do |rs| + assert_single_read rs, Int32, 123 + end + end + + it "gets column count" do + PG_DB.query "select 1::int4, 1::int4" do |rs| + rs.column_count.should eq(2) + end + end + + it "gets column names" do + PG_DB.query "select 1::int4 as foo, 1::int4 as bar" do |rs| + rs.column_name(0).should eq("foo") + rs.column_name(1).should eq("bar") + end + end + + it "should raise an exception if unique constraint is violated" do + expect_raises(PQ::PQError) do + PG_DB.exec "drop table if exists contacts" + PG_DB.exec "create table contacts (name varchar(256), CONSTRAINT key_name UNIQUE(name))" + + result = PG_DB.query "insert into contacts values ($1)", "Foo" + result = PG_DB.query "insert into contacts values ($1)", "Foo" do |rs| + rs.move_next + end + end + end + + it "executes insert" do + PG_DB.exec "drop table if exists contacts" + PG_DB.exec "create table contacts (name varchar(256), age int4)" + + result = PG_DB.exec "insert into contacts values ($1, $2)", "Foo", 10 + + result.last_insert_id.should eq(0) # postgres doesn't support this + result.rows_affected.should eq(1) + end + + it "executes insert via query" do + PG_DB.query("drop table if exists contacts") do |rs| + rs.move_next.should be_false + end + end + + it "executes update" do + PG_DB.exec "drop table if exists contacts" + PG_DB.exec "create table contacts (name varchar(256), age int4)" + + PG_DB.exec "insert into contacts values ($1, $2)", "Foo", 10 + PG_DB.exec "insert into contacts values ($1, $2)", "Baz", 10 + PG_DB.exec "insert into contacts values ($1, $2)", "Baz", 20 + + result = PG_DB.exec "update contacts set age = 30 where age = 10" + + result.last_insert_id.should eq(0) # postgres doesn't support this + result.rows_affected.should eq(2) + end + + it "traverses result set" do + PG_DB.exec "drop table if exists contacts" + PG_DB.exec "create table contacts (name varchar(256), age int4)" + + PG_DB.exec "insert into contacts values ($1, $2)", "Foo", 10 + PG_DB.exec "insert into contacts values ($1, $2)", "Bar", 20 + + PG_DB.query "select name, age from contacts order by age" do |rs| + rs.move_next.should be_true + rs.read(String).should eq("Foo") + rs.move_next.should be_true + rs.read(String).should eq("Bar") + rs.move_next.should be_false + end + end + + describe "transactions" do + it "can read inside transaction and rollback after" do + with_db do |db| + db.exec "drop table if exists person" + db.exec "create table person (name varchar(25))" + db.transaction do |tx| + tx.connection.scalar("select count(*) from person").should eq(0) + tx.connection.exec "insert into person (name) values ($1)", "John Doe" + tx.connection.scalar("select count(*) from person").should eq(1) + tx.rollback + end + db.scalar("select count(*) from person").should eq(0) + end + end + + it "can read inside transaction or after commit" do + with_db do |db| + db.exec "drop table if exists person" + db.exec "create table person (name varchar(25))" + db.transaction do |tx| + tx.connection.scalar("select count(*) from person").should eq(0) + tx.connection.exec "insert into person (name) values ($1)", "John Doe" + tx.connection.scalar("select count(*) from person").should eq(1) + # using other connection + db.scalar("select count(*) from person").should eq(0) + end + db.scalar("select count(*) from person").should eq(1) + end + end + end + + describe "nested transactions" do + it "can read inside transaction and rollback after" do + with_db do |db| + db.exec "drop table if exists person" + db.exec "create table person (name varchar(25))" + db.transaction do |tx_0| + tx_0.connection.scalar("select count(*) from person").should eq(0) + tx_0.connection.exec "insert into person (name) values ($1)", "John Doe" + tx_0.transaction do |tx_1| + tx_1.connection.exec "insert into person (name) values ($1)", "Sarah" + tx_1.connection.scalar("select count(*) from person").should eq(2) + tx_1.transaction do |tx_2| + tx_2.connection.exec "insert into person (name) values ($1)", "Jimmy" + tx_2.connection.scalar("select count(*) from person").should eq(3) + tx_2.rollback + end + end + tx_0.connection.scalar("select count(*) from person").should eq(2) + tx_0.rollback + end + db.scalar("select count(*) from person").should eq(0) + end + end + end + + describe "move_next" do + it "properly skips null columns" do + no_nulls = StructWithMapping.from_rs(PG_DB.query("select 1 as a, 1 as b")).first + {no_nulls.a, no_nulls.b}.should eq({1, 1}) + + message = "PG::ResultSet#read returned a Nil. A Int32 was expected." + expect_raises(Exception, message) do + StructWithMapping.from_rs(PG_DB.query("select 2 as a, null as b")) + end + + expect_raises(Exception, message) do # importantly not an IndexError: Index out of bounds + StructWithMapping.from_rs(PG_DB.query("select null as a, null as b")) + end + end + end +end diff --git a/lib/pg/spec/pg/encoder_spec.cr b/lib/pg/spec/pg/encoder_spec.cr new file mode 100644 index 00000000..d326d02e --- /dev/null +++ b/lib/pg/spec/pg/encoder_spec.cr @@ -0,0 +1,65 @@ +require "../spec_helper" + +private def test_insert_and_read(datatype, value, file = __FILE__, line = __LINE__) + it "inserts #{datatype}", file, line do + PG_DB.exec "drop table if exists test_table" + PG_DB.exec "create table test_table (v #{datatype})" + + # Read casting the value + PG_DB.exec "insert into test_table values ($1)", [value] + actual_value = PG_DB.query_one "select v from test_table", as: value.class + actual_value.should eq(value) + + # Read without casting the value + actual_value = PG_DB.query_one "select v from test_table", &.read + actual_value.should eq(value) + end +end + +describe PG::Driver, "encoder" do + test_insert_and_read "int4", 123 + + test_insert_and_read "float", 12.34 + + test_insert_and_read "varchar", "hello world" + + test_insert_and_read "timestamp", Time.utc(2015, 2, 3, 17, 15, nanosecond: 13_000_000) + test_insert_and_read "timestamp", Time.utc(2015, 2, 3, 17, 15, 13, nanosecond: 11_000_000) + test_insert_and_read "timestamp", Time.utc(2015, 2, 3, 17, 15, 13, nanosecond: 123_456_000) + test_insert_and_read "timestamptz", Time.local(2019, 8, 13, 12, 30, location: Time::Location.fixed(-14_400)) + + test_insert_and_read "bool[]", [true, false, true] + + test_insert_and_read "float[]", [1.2, 3.4, 5.6] + + test_insert_and_read "integer[]", [] of Int32 + test_insert_and_read "integer[]", [1, 2, 3] + test_insert_and_read "integer[]", [[1, 2], [3, 4]] + + test_insert_and_read "text[]", ["t", "f", "t"] + test_insert_and_read "text[]", [%("a), %(\\b~), %(c\\"d), %(\uFF8F)] + test_insert_and_read "text[]", ["baz, bar"] + test_insert_and_read "text[]", ["foo}"] + + describe "geo" do + test_insert_and_read "point", PG::Geo::Point.new(1.2, 3.4) + if Helper.db_version_gte(9, 4) + test_insert_and_read "line", PG::Geo::Line.new(1.2, 3.4, 5.6) + end + test_insert_and_read "circle", PG::Geo::Circle.new(1.2, 3.4, 5.6) + test_insert_and_read "lseg", PG::Geo::LineSegment.new(1.2, 3.4, 5.6, 7.8) + test_insert_and_read "box", PG::Geo::Box.new(1.2, 3.4, 5.6, 7.8) + test_insert_and_read "path", PG::Geo::Path.new([ + PG::Geo::Point.new(1.2, 3.4), + PG::Geo::Point.new(5.6, 7.8), + ], closed: false) + test_insert_and_read "path", PG::Geo::Path.new([ + PG::Geo::Point.new(1.2, 3.4), + PG::Geo::Point.new(5.6, 7.8), + ], closed: true) + test_insert_and_read "polygon", PG::Geo::Polygon.new([ + PG::Geo::Point.new(1.2, 3.4), + PG::Geo::Point.new(5.6, 7.8), + ]) + end +end diff --git a/lib/pg/spec/pg/escape_helper_spec.cr b/lib/pg/spec/pg/escape_helper_spec.cr new file mode 100644 index 00000000..b1dba526 --- /dev/null +++ b/lib/pg/spec/pg/escape_helper_spec.cr @@ -0,0 +1,19 @@ +require "../spec_helper" + +describe PG::Connection, "#escape_literal" do + it { escape_literal(%(foo)).should eq(%('foo')) } + it { escape_literal(%(this has a \\)).should eq(%( E'this has a \\\\')) } + it { escape_literal(%(what's your "name")).should eq(%('what''s your "name"')) } + it { escape_literal(%(foo).to_slice).should eq(%('\\x666f6f')) } + # it "raises on invalid strings" do + # expect_raises(PG::ConnectionError) { escape_literal("\u{F4}") } + # end +end + +describe PG::Connection, "#escape_identifier" do + it { escape_identifier(%(foo)).should eq(%("foo")) } + it { escape_identifier(%(what's \\ your "name")).should eq(%("what's \\ your ""name""")) } + # it "raises on invalid strings" do + # expect_raises(PG::ConnectionError) { escape_identifier("\u{F4}") } + # end +end diff --git a/lib/pg/spec/pg/numeric_spec.cr b/lib/pg/spec/pg/numeric_spec.cr new file mode 100644 index 00000000..b1bbd78e --- /dev/null +++ b/lib/pg/spec/pg/numeric_spec.cr @@ -0,0 +1,138 @@ +require "../spec_helper" +require "../../src/pg_ext/big_rational" + +private def n(nd, w, s, ds, d) + PG::Numeric.new(nd.to_i16, w.to_i16, s.to_i16, ds.to_i16, d.map(&.to_i16)) +end + +private def br(n, d) + BigRational.new(n, d) +end + +private def ex(which) + case which + when "nan" + n(0, 0, -16384, 0, [] of Int16) + when "0" + n(0, 0, 0, 0, [] of Int16) + when "0.0" + n(0, 0, 0, 1, [] of Int16) + when "1" + n(1, 0, 0, 0, [1]) + when "-1" + n(1, 0, 0x4000, 0, [1]) + when "1.3" + n(2, 0, 0, 1, [1, 3000]) + when "1.30" + n(2, 0, 0, 2, [1, 3000]) + when "12345.6789123" + n(4, 1, 0, 7, [1, 2345, 6789, 1230]) + when "-0.00009" + n(1, -2, 0x4000, 5, [9000]) + when "-0.000009" + n(1, -2, 0x4000, 6, [900]) + when "-0.0000009" + n(1, -2, 0x4000, 7, [90]) + when "-0.00000009" + n(1, -2, 0x4000, 8, [9]) + when "0.0...9" + n(2, -10, 0, 43, [9999, 9990]) + when "800000" + n(1, 1, 0, 0, [80]) + when "50093" + n(2, 1, 0, 0, [5, 93]) + when "500000093" + n(3, 2, 0, 0, [5, 0, 93]) + when "0.3" + n(1, -1, 0, 1, [3000]) + when "0.03" + n(1, -1, 0, 2, [300]) + when "0.003" + n(1, -1, 0, 3, [30]) + when "0.000300003" + n(3, -1, 0, 9, [3, 0, 3000]) + when "0.0000006000000" + n(1, -2, 0, 13, [60]) + when "50093.60754417" + n(4, 1, 0, 8, [5, 93, 6075, 4417]) + else + raise "no example #{which}" + end +end + +describe PG::Numeric do + it "#to_f" do + [ + {"nan", 0_f64}, + {"0", 0_f64}, + {"0.0", 0_f64}, + {"1", 1_f64}, + {"-1", -1_f64}, + {"1.3", 1.3_f64}, + {"1.30", 1.3_f64}, + {"12345.6789123", 12345.6789123_f64}, + {"-0.00009", -0.00009_f64}, + {"-0.000009", -0.000009_f64}, + {"-0.0000009", -0.0000009_f64}, + {"-0.00000009", -0.00000009_f64}, + {"0.0...9", 0.0000000000000000000000000000000000009999999_f64}, + ].each do |x| + ex(x[0]).to_f.should be_close(x[1], 1e-50) + end + end + + it "#to_big_r" do + [ + {"nan", br(0, 1)}, + {"0", br(0, 1)}, + {"0.0", br(0, 1)}, + {"1", br(1, 1)}, + {"-1", br(-1, 1)}, {"1.3", br(13, 10)}, + {"1.30", br(13, 10)}, + {"12345.6789123", br(123456789123, 10000000)}, + {"-0.00009", br(-9, 100000)}, + {"-0.000009", br(-9, 1000000)}, + {"-0.0000009", br(-9, 10000000)}, + {"-0.00000009", br(-9, 100000000)}, + {"0.0...9", br(BigInt.new(9999999), BigInt.new(10)**43)}, + ].each do |x| + ex(x[0]).to_big_r.should eq(x[1]) + end + end + + it "#to_s" do + [ + {"nan", "NaN"}, + {"0", "0"}, + {"0.0", "0.0"}, + {"1", "1"}, + {"-1", "-1"}, + {"1.3", "1.3"}, + {"1.30", "1.30"}, + {"12345.6789123", "12345.6789123"}, + {"800000", "800000"}, + {"0.3", "0.3"}, + {"0.03", "0.03"}, + {"0.003", "0.003"}, + {"0.000300003", "0.000300003"}, + {"-0.00009", "-0.00009"}, + {"-0.000009", "-0.000009"}, + {"-0.0000009", "-0.0000009"}, + {"-0.00000009", "-0.00000009"}, + {"0.0...9", "0.0000000000000000000000000000000000009999999"}, + {"50093", "50093"}, + {"500000093", "500000093"}, + {"0.0000006000000", "0.0000006000000"}, + {"50093.60754417", "50093.60754417"}, + ].each do |x| + ex(x[0]).to_s.should eq(x[1]) + ex(x[0]).inspect.should eq(x[1]) + end + end + + it "#nan?" do + ex("nan").nan?.should be_true + ex("1").nan?.should be_false + ex("-1").nan?.should be_false + end +end diff --git a/lib/pg/spec/pq/authentication_methods_spec.cr b/lib/pg/spec/pq/authentication_methods_spec.cr new file mode 100644 index 00000000..57205d69 --- /dev/null +++ b/lib/pg/spec/pq/authentication_methods_spec.cr @@ -0,0 +1,101 @@ +require "../spec_helper" + +# The following specs requires specific lines in the local pg_hba.conf file +# * crystal_md5 user with md5 method +# * and if the line is trust for everything, it needs to be restricted to +# just your user +# Because of this, most of these specs are disabled by default. To enable them +# place an empty file called .run_auth_specs in /spec + +private def test_role(role, pass) + db = PG_DB.query_one("select current_database()", &.read) + url = "postgres://#{role}:#{pass}@127.0.0.1/#{db}" + DB.open(url) do |db| + db.query_one("select 1", &.read).should eq(1) + end +end + +describe PQ::Connection, "nologin role" do + it "raises" do + PG_DB.exec("drop role if exists crystal_test") + PG_DB.exec("create role crystal_test nologin") + expect_raises(DB::ConnectionRefused) { + DB.open("postgres://crystal_test@localhost") + } + PG_DB.exec("drop role if exists crystal_test") + end +end + +if File.exists?(File.join(File.dirname(__FILE__), "../.run_auth_specs")) + describe PQ::Connection, "scram auth" do + it "works when given the correct password" do + PG_DB.exec("drop role if exists crystal_scram") + PG_DB.exec("set password_encryption='scram-sha-256'") + PG_DB.exec("create role crystal_scram login encrypted password 'pass'") + + test_role("crystal_scram", "pass") + + PG_DB.exec("drop role if exists crystal_scram") + end + + it "fails with a bad password" do + PG_DB.exec("drop role if exists crystal_scram") + PG_DB.exec("set password_encryption='scram-sha-256'") + PG_DB.exec("create role crystal_scram login encrypted password 'pass'") + + expect_raises(DB::ConnectionRefused) { + test_role("crystal_scram", "wrong") + } + + expect_raises(DB::ConnectionRefused) { + test_role("crystal_scram", "") + } + + PG_DB.exec("drop role if exists crystal_scram") + end + end if Helper.db_version_gte(10) + + describe PQ::Connection, "md5 auth" do + it "works when given the correct password" do + PG_DB.exec("drop role if exists crystal_md5") + PG_DB.exec("set password_encryption='md5'") if Helper.db_version_gte(10) + PG_DB.exec("create role crystal_md5 login encrypted password 'pass'") + test_role("crystal_md5", "pass") + PG_DB.exec("drop role if exists crystal_md5") + end + + it "fails when given the wrong password" do + PG_DB.exec("drop role if exists crystal_md5") + PG_DB.exec("set password_encryption='md5'") if Helper.db_version_gte(10) + PG_DB.exec("create role crystal_md5 login encrypted password 'pass'") + + expect_raises(DB::ConnectionRefused) { + test_role("crystal_md5", "bad") + } + + expect_raises(DB::ConnectionRefused) { + test_role("crystal_md5", "") + } + + PG_DB.exec("drop role if exists crystal_md5") + end + end + + describe PQ::Connection, "ssl clientcert auth" do + it "works when using ssl clientcert" do + PG_DB.exec("drop role if exists crystal_ssl") + PG_DB.exec("create role crystal_ssl login encrypted password 'pass'") + db = PG_DB.query_one("select current_database()", &.read) + certs = File.join Dir.current, ".cert" + uri = "postgres://crystal_ssl@127.0.0.1/#{db}?sslmode=verify-full&sslcert=#{certs}/crystal_ssl.crt&sslkey=#{certs}/crystal_ssl.key&sslrootcert=#{certs}/root.crt" + DB.open(uri) do |db| + db.query_one("select current_user", &.read).should eq("crystal_ssl") + end + PG_DB.exec("drop role if exists crystal_ssl") + end + end +else + describe "auth specs" do + pending "skipped: see file for details" { } + end +end diff --git a/lib/pg/spec/pq/connection_spec.cr b/lib/pg/spec/pq/connection_spec.cr new file mode 100644 index 00000000..effab2cd --- /dev/null +++ b/lib/pg/spec/pq/connection_spec.cr @@ -0,0 +1,28 @@ +require "../spec_helper" + +module PG + class Connection + getter connection + end +end + +describe PQ::Connection, "#server_parameters" do + it "ParameterStatus frames in response to set are handeled" do + get = ->{ PG_DB.using_connection &.connection.server_parameters["standard_conforming_strings"] } + get.call.should eq("on") + PG_DB.exec "set standard_conforming_strings to on" + get.call.should eq("on") + PG_DB.exec "set standard_conforming_strings to off" + get.call.should eq("off") + PG_DB.exec "set standard_conforming_strings to default" + get.call.should eq("on") + end +end + +describe PQ::Connection do + it "handles empty queries" do + PG_DB.exec "" + PG_DB.query("") { } + PG_DB.query_one("select 1", &.read).should eq(1) + end +end diff --git a/lib/pg/spec/pq/conninfo_spec.cr b/lib/pg/spec/pq/conninfo_spec.cr new file mode 100644 index 00000000..2032afb0 --- /dev/null +++ b/lib/pg/spec/pq/conninfo_spec.cr @@ -0,0 +1,81 @@ +require "spec" +require "../../src/pq/conninfo" + +private def assert_default_params(ci) + (PQ::ConnInfo::SOCKET_SEARCH + ["localhost"]).should contain(ci.host) + ci.database.should_not eq(nil) + ci.user.should_not eq(nil) + ci.database.should eq(ci.user) + ci.password.should eq(nil) + ci.port.should eq(5432) + ci.sslmode.should eq(:prefer) +end + +private def assert_custom_params(ci) + ci.host.should eq("host") + ci.database.should eq("db") + ci.user.should eq("user") + ci.password.should eq("pass") + ci.port.should eq(5555) + ci.sslmode.should eq(:require) +end + +private def assert_ssl_params(ci) + ci.sslmode.should eq(:"verify-full") + ci.sslcert.should eq("postgresql.crt") + ci.sslkey.should eq("postgresql.key") + ci.sslrootcert.should eq("root.crt") +end + +describe PQ::ConnInfo, "parts" do + it "can have all defaults" do + ci = PQ::ConnInfo.new + assert_default_params ci + end + + it "can take settings" do + ci = PQ::ConnInfo.new("host", "db", "user", "pass", 5555, :require) + assert_custom_params ci + end +end + +describe PQ::ConnInfo, ".from_conninfo_string" do + it "parses short postgres urls" do + ci = PQ::ConnInfo.from_conninfo_string("postgres:///") + assert_default_params ci + end + + it "parses postgres urls" do + ci = PQ::ConnInfo.from_conninfo_string( + "postgres://user:pass@host:5555/db?sslmode=require&otherparam=ignore") + assert_custom_params ci + + ci = PQ::ConnInfo.from_conninfo_string( + "postgres://user:pass@host:5555/db?sslmode=verify-full&sslcert=postgresql.crt&sslkey=postgresql.key&sslrootcert=root.crt") + assert_ssl_params ci + + ci = PQ::ConnInfo.from_conninfo_string( + "postgresql://user:pass@host:5555/db?sslmode=require") + assert_custom_params ci + end + + it "parses libpq style strings" do + ci = PQ::ConnInfo.from_conninfo_string( + "host=host dbname=db user=user password=pass port=5555 sslmode=require") + assert_custom_params ci + + ci = PQ::ConnInfo.from_conninfo_string( + "host=host dbname=db user=user password=pass port=5555 sslmode=verify-full sslcert=postgresql.crt sslkey=postgresql.key sslrootcert=root.crt") + assert_ssl_params ci + + ci = PQ::ConnInfo.from_conninfo_string("host=host") + ci.host.should eq("host") + + ci = PQ::ConnInfo.from_conninfo_string("") + assert_default_params ci + + expect_raises(ArgumentError) { + PQ::ConnInfo.from_conninfo_string("hosthost") + } + end +end diff --git a/lib/pg/spec/pq/param_spec.cr b/lib/pg/spec/pq/param_spec.cr new file mode 100644 index 00000000..1a5c2b5c --- /dev/null +++ b/lib/pg/spec/pq/param_spec.cr @@ -0,0 +1,23 @@ +require "spec" +require "../../src/pq/param" + +private def it_encodes_array(value, encoded) + it "encodes #{value.class}" do + PQ::Param.encode_array(value).should eq encoded + end +end + +describe PQ::Param do + describe "encoders" do + describe "#encode_array" do + it_encodes_array([] of String, "{}") + it_encodes_array([true, false, true], "{t,f,t}") + it_encodes_array(["t", "f", "t"], %({"t","f","t"})) + it_encodes_array([1, 2, 3], "{1,2,3}") + it_encodes_array([1.2, 3.4, 5.6], "{1.2,3.4,5.6}") + it_encodes_array([%(a), %(\\b~), %(c\\"d), %(\uFF8F)], %({"a","\\\\b~","c\\\\\\"d","\uFF8F"})) + it_encodes_array(["baz, bar"], %({"baz, bar"})) + it_encodes_array(["foo}"], %({"foo}"})) + end + end +end diff --git a/lib/pg/spec/spec_helper.cr b/lib/pg/spec/spec_helper.cr new file mode 100644 index 00000000..0848ee43 --- /dev/null +++ b/lib/pg/spec/spec_helper.cr @@ -0,0 +1,39 @@ +require "spec" +require "../src/pg" + +DB_URL = ENV["DATABASE_URL"]? || "postgres:///" +PG_DB = DB.open(DB_URL) + +def with_db + DB.open(DB_URL) do |db| + yield db + end +end + +def with_connection + DB.connect(DB_URL) do |conn| + yield conn + end +end + +def escape_literal(string) + with_connection &.escape_literal(string) +end + +def escape_identifier(string) + with_connection &.escape_identifier(string) +end + +module Helper + def self.db_version_gte(major, minor = 0, patch = 0) + ver = with_connection &.version + ver[:major] >= major && ver[:minor] >= minor && ver[:patch] >= patch + end +end + +def test_decode(name, query, expected, file = __FILE__, line = __LINE__) + it name, file, line do + value = PG_DB.query_one "select #{query}", &.read + value.should eq(expected), file, line + end +end diff --git a/lib/pg/src/pg.cr b/lib/pg/src/pg.cr new file mode 100644 index 00000000..79c3abad --- /dev/null +++ b/lib/pg/src/pg.cr @@ -0,0 +1,44 @@ +require "db" +require "./pg/*" + +module PG + # Establish a connection to the database + def self.connect(url) + DB.open(url) + end + + # Establish a special listen-only connection to the database. + # + # ``` + # PG.connect_listen(ENV["DATABASE_URL"], "foo", "bar") do |notification| + # pp notification.channel, notification.payload, notification.pid + # end + # ``` + def self.connect_listen(url, *channels : String, &blk : PQ::Notification ->) : ListenConnection + connect_listen(url, channels, &blk) + end + + # ditto + def self.connect_listen(url, channels : Enumerable(String), &blk : PQ::Notification ->) : ListenConnection + ListenConnection.new(url, channels, &blk) + end + + class ListenConnection + @conn : PG::Connection + + def self.new(url, *channels : String, &blk : PQ::Notification ->) + new(url, channels, &blk) + end + + def initialize(url, channels : Enumerable(String), &blk : PQ::Notification ->) + @conn = DB.connect(url).as(PG::Connection) + @conn.on_notification(&blk) + @conn.listen(channels) + end + + # Close the connection. + def close + @conn.close + end + end +end diff --git a/lib/pg/src/pg/connection.cr b/lib/pg/src/pg/connection.cr new file mode 100644 index 00000000..75b3cdda --- /dev/null +++ b/lib/pg/src/pg/connection.cr @@ -0,0 +1,67 @@ +require "../pq/*" + +module PG + class Connection < ::DB::Connection + protected getter connection + + def initialize(context) + super + @connection = uninitialized PQ::Connection + + begin + conn_info = PQ::ConnInfo.new(context.uri) + @connection = PQ::Connection.new(conn_info) + @connection.connect + rescue ex + raise DB::ConnectionRefused.new(cause: ex) + end + end + + def build_prepared_statement(query) : Statement + Statement.new(self, query) + end + + def build_unprepared_statement(query) : Statement + Statement.new(self, query) + end + + # Execute several statements. No results are returned. + def exec_all(query : String) : Nil + PQ::SimpleQuery.new(@connection, query) + nil + end + + # Set the callback block for notices and errors. + def on_notice(&on_notice_proc : PQ::Notice ->) + @connection.notice_handler = on_notice_proc + end + + # Set the callback block for notifications from Listen/Notify. + def on_notification(&on_notification_proc : PQ::Notification ->) + @connection.notification_handler = on_notification_proc + end + + protected def listen(channels : Enumerable(String)) + channels.each { |c| exec_all("LISTEN " + escape_identifier(c)) } + listen + end + + protected def listen + spawn { @connection.read_async_frame_loop } + end + + def version + vers = connection.server_parameters["server_version"].partition(' ').first.split('.').map(&.to_i) + {major: vers[0], minor: vers[1], patch: vers[2]? || 0} + end + + protected def do_close + super + + begin + @connection.close + rescue + end + end + end +end diff --git a/lib/pg/src/pg/decoder.cr b/lib/pg/src/pg/decoder.cr new file mode 100644 index 00000000..e43899f0 --- /dev/null +++ b/lib/pg/src/pg/decoder.cr @@ -0,0 +1,510 @@ +require "json" + +module PG + alias PGValue = String | Nil | Bool | Int32 | Float32 | Float64 | Time | JSON::Any | PG::Numeric + + # :nodoc: + module Decoders + module Decoder + abstract def decode(io, bytesize, oid) + abstract def oids : Array(Int32) + abstract def type + + macro def_oids(oids) + OIDS = {{oids}} + + def oids : Array(Int32) + OIDS + end + end + + def read(io, type) + io.read_bytes(type, IO::ByteFormat::NetworkEndian) + end + + def read_i16(io) + read(io, Int16) + end + + def read_i32(io) + read(io, Int32) + end + + def read_i64(io) + read(io, Int64) + end + + def read_u32(io) + read(io, UInt32) + end + + def read_u64(io) + read(io, UInt64) + end + + def read_f32(io) + read(io, Float32) + end + + def read_f64(io) + read(io, Float64) + end + end + + struct StringDecoder + include Decoder + + UUID_OID = 2950 + + def_oids [ + 19, # name (internal type) + 25, # text + 142, # xml + 705, # unknown + 1042, # blchar + 1043, # varchar + UUID_OID, # uuid + ] + + def decode(io, bytesize, oid) + if oid == UUID_OID + return decode_uuid(io, bytesize) + end + + String.new(bytesize) do |buffer| + io.read_fully(Slice.new(buffer, bytesize)) + {bytesize, 0} + end + end + + private def decode_uuid(io, bytesize) + bytes = uninitialized UInt8[6] + + String.new(36) do |buffer| + buffer[8] = buffer[13] = buffer[18] = buffer[23] = 45_u8 + + slice = bytes.to_slice[0, 4] + + io.read_fully(slice) + slice.hexstring(buffer + 0) + + slice = bytes.to_slice[0, 2] + + io.read_fully(slice) + slice.hexstring(buffer + 9) + + io.read_fully(slice) + slice.hexstring(buffer + 14) + + io.read_fully(slice) + slice.hexstring(buffer + 19) + + slice = bytes.to_slice + io.read_fully(slice) + slice.hexstring(buffer + 24) + + {36, 36} + end + end + + def type + String + end + end + + struct CharDecoder + include Decoder + + def_oids [ + 18, # "char" (internal type) + ] + + def decode(io, bytesize, oid) + # TODO: can be done without creating an intermediate string + String.new(bytesize) do |buffer| + io.read_fully(Slice.new(buffer, bytesize)) + {bytesize, 0} + end[0] + end + + def type + Char + end + end + + struct BoolDecoder + include Decoder + + OIDS = [ + 16, # bool + ] + + def decode(io, bytesize, oid) + case byte = io.read_byte + when 0 + false + when 1 + true + else + raise "bad boolean decode: #{byte}" + end + end + + def oids : Array(Int32) + OIDS + end + + def type + Bool + end + end + + struct Int16Decoder + include Decoder + + def_oids [ + 21, # int2 (smallint) + ] + + def decode(io, bytesize, oid) + read_i16(io) + end + + def type + Int16 + end + end + + struct Int32Decoder + include Decoder + + def_oids [ + 23, # int4 (integer) + 2206, # regtype + ] + + def decode(io, bytesize, oid) + read_i32(io) + end + + def type + Int32 + end + end + + struct Int64Decoder + include Decoder + + def_oids [ + 20, # int8 (bigint) + ] + + def decode(io, bytesize, oid) + read_u64(io).to_i64 + end + + def type + Int64 + end + end + + struct UIntDecoder + include Decoder + + def_oids [ + 26, # oid (internal type) + ] + + def decode(io, bytesize, oid) + read_u32(io) + end + + def type + UInt32 + end + end + + struct Float32Decoder + include Decoder + + def_oids [ + 700, # float4 + ] + + def decode(io, bytesize, oid) + read_f32(io) + end + + def type + Float32 + end + end + + struct Float64Decoder + include Decoder + + def_oids [ + 701, # float8 + ] + + def decode(io, bytesize, oid) + read_f64(io) + end + + def type + Float64 + end + end + + struct PointDecoder + include Decoder + + def_oids [ + 600, # point + ] + + def decode(io, bytesize, oid) + Geo::Point.new(read_f64(io), read_f64(io)) + end + + def type + Geo::Point + end + end + + struct PathDecoder + include Decoder + + def_oids [ + 602, # path + ] + + def decode(io, bytesize, oid) + byte = io.read_byte.not_nil! + closed = byte == 1_u8 + Geo::Path.new(PolygonDecoder.new.decode(io, bytesize - 1, oid).points, closed) + end + + def type + Geo::Path + end + end + + struct PolygonDecoder + include Decoder + + def_oids [ + 604, # polygon + ] + + def decode(io, bytesize, oid) + c = read_u32(io) + count = (pointerof(c).as(Int32*)).value + points = Array.new(count) do |i| + PointDecoder.new.decode(io, 16, oid) + end + Geo::Polygon.new(points) + end + + def type + Geo::Polygon + end + end + + struct BoxDecoder + include Decoder + + def_oids [ + 603, # box + ] + + def decode(io, bytesize, oid) + x2, y2, x1, y1 = read_f64(io), read_f64(io), read_f64(io), read_f64(io) + Geo::Box.new(x1, y1, x2, y2) + end + + def type + Geo::Box + end + end + + struct LineSegmentDecoder + include Decoder + + def_oids [ + 601, # lseg + ] + + def decode(io, bytesize, oid) + Geo::LineSegment.new(read_f64(io), read_f64(io), read_f64(io), read_f64(io)) + end + + def type + Geo::LineSegment + end + end + + struct LineDecoder + include Decoder + + def_oids [ + 628, # line + ] + + def decode(io, bytesize, oid) + Geo::Line.new(read_f64(io), read_f64(io), read_f64(io)) + end + + def type + Geo::Line + end + end + + struct CircleDecoder + include Decoder + + def_oids [ + 718, # circle + ] + + def decode(io, bytesize, oid) + Geo::Circle.new(read_f64(io), read_f64(io), read_f64(io)) + end + + def type + Geo::Circle + end + end + + struct JsonDecoder + include Decoder + + JSONB_OID = 3802 + + def_oids [ + 114, # json + JSONB_OID, # jsonb + ] + + def decode(io, bytesize, oid) + if oid == JSONB_OID + io.read_byte + bytesize -= 1 + end + + string = String.new(bytesize) do |buffer| + io.read_fully(Slice.new(buffer, bytesize)) + {bytesize, 0} + end + JSON.parse(string) + end + + def type + JSON::Any + end + end + + struct TimeDecoder + include Decoder + + DATE_OID = 1082 + JAN_1_2K = Time.utc(2000, 1, 1) + + def_oids [ + DATE_OID, # date + 1114, # timestamp + 1184, # timestamptz + ] + + def decode(io, bytesize, oid) + if oid == DATE_OID + v = read_i32(io) + JAN_1_2K + Time::Span.new(days: v, hours: 0, minutes: 0, seconds: 0) + else + v = read_i64(io) # microseconds + sec, m = v.divmod(1_000_000) + JAN_1_2K + Time::Span.new(seconds: sec, nanoseconds: m*1000) + end + end + + def type + Time + end + end + + struct ByteaDecoder + include Decoder + + def_oids [ + 17, # bytea + ] + + def decode(io, bytesize, oid) + slice = Bytes.new(bytesize) + io.read_fully(slice) + slice + end + + def type + Bytes + end + end + + struct NumericDecoder + include Decoder + + def_oids [ + 1700, # numeric + ] + + def decode(io, bytesize, oid) + ndigits = read_i16(io) + weight = read_i16(io) + sign = read_i16(io) + dscale = read_i16(io) + digits = (0...ndigits).map { |i| read_i16(io) } + PG::Numeric.new(ndigits, weight, sign, dscale, digits) + end + + def type + PG::Numeric + end + end + + @@decoders = Hash(Int32, PG::Decoders::Decoder).new(ByteaDecoder.new) + + def self.from_oid(oid) + @@decoders[oid] + end + + def self.register_decoder(decoder) + decoder.oids.each do |oid| + @@decoders[oid] = decoder + end + end + + # https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_type.h + register_decoder BoolDecoder.new + register_decoder ByteaDecoder.new + register_decoder CharDecoder.new + register_decoder StringDecoder.new + register_decoder Int16Decoder.new + register_decoder Int32Decoder.new + register_decoder Int64Decoder.new + register_decoder UIntDecoder.new + register_decoder JsonDecoder.new + register_decoder Float32Decoder.new + register_decoder Float64Decoder.new + register_decoder TimeDecoder.new + register_decoder NumericDecoder.new + register_decoder PointDecoder.new + register_decoder LineSegmentDecoder.new + register_decoder PathDecoder.new + register_decoder BoxDecoder.new + register_decoder PolygonDecoder.new + register_decoder LineDecoder.new + register_decoder CircleDecoder.new + end +end + +require "./decoders/*" diff --git a/lib/pg/src/pg/decoders/array_decoder.cr b/lib/pg/src/pg/decoders/array_decoder.cr new file mode 100644 index 00000000..7da461f3 --- /dev/null +++ b/lib/pg/src/pg/decoders/array_decoder.cr @@ -0,0 +1,176 @@ +require "../numeric" + +module PG + module Decoders + # Generic Array decoder: decodes to a recursive array type + struct ArrayDecoder(T, D) + include Decoder + + getter oids : Array(Int32) + + def self.new(oid : Int32) + new([oid]) + end + + def initialize(@oids : Array(Int32)) + end + + def decode(io, bytesize, oid) + header = Decoders.decode_array_header(io) + + if header.dimensions == 0 + ([] of T).as(T) + elsif header.dimensions == 1 && header.dim_info.first[:lbound] == 1 + # allow casting down to unnested crystal arrays + build_simple_array(io, header.dim_info.first[:dim], header.oid).as(T) + else + if header.dim_info.any? { |di| di[:lbound] < 1 } + raise PG::RuntimeError.new("Only lower-bounds >= 1 are supported") + end + + # recursively build nested array + get_element(io, header.dim_info, header.oid).as(T) + end + end + + def build_simple_array(io, size, oid) + Array(T).new(size) { get_next(io, oid) } + end + + def get_element(io, dim_info, oid) + if dim_info.size == 1 + lbound = dim_info.first[:lbound] - 1 # in lower-bound is not 1 + Array(T).new(dim_info.first[:dim] + lbound) do |i| + i < lbound ? nil : get_next(io, oid) + end + else + Array(T).new(dim_info.first[:dim]) do |i| + get_element(io, dim_info[1..-1], oid) + end + end + end + + def get_next(io, oid) + bytesize = read_i32(io) + if bytesize == -1 + nil + else + D.new.decode(io, bytesize, oid) + end + end + + def type + T + end + end + + # Specific array decoder method: decodes to exactly Array(T). + # Used when invoking, for example `rs.read(Array(Int32))`. + def self.decode_array(io, bytesize, t : Array(T).class) forall T + header = decode_array_header(io) + + decoder = array_decoder(T) + unless decoder.oids.includes?(header.oid) + correct_decoder = Decoders.from_oid(header.oid) + + raise PG::RuntimeError.new("Can't read column of type Array(#{correct_decoder.type}) as Array(#{flatten_type(T)})") + end + + if header.dimensions == 0 + return [] of T + end + + decode_array_element(io, t, header.dim_info, decoder, header.oid) + end + + def self.decode_array_element(io, t : Array(T).class, dim_info, decoder, oid) forall T + size = dim_info.first[:dim] + rest = dim_info[1..-1] + + Array(T).new(size) do + decode_array_element(io, T, rest, decoder, oid) + end + end + + def self.decode_array_element(io, t : T.class, dim_info, decoder, oid) forall T + bytesize = read_i32(io) + if bytesize == -1 + {% if T.nilable? %} + nil + {% else %} + raise PG::RuntimeError.new("unexpected NULL") + {% end %} + else + decoder.decode(io, bytesize, oid) + end + end + + def self.array_decoder(t : Array(T).class) forall T + array_decoder(T) + end + + {% for type in %w(Bool Char Int16 Int32 String Int64 Float32 Float64 Numeric Time).map(&.id) %} + def self.array_decoder(t : {{type}}?.class) + {{type}}Decoder.new + end + + def self.array_decoder(t : {{type}}.class) + {{type}}Decoder.new + end + {% end %} + + def self.flatten_type(t : Array(T).class) forall T + flatten_type(T) + end + + def self.flatten_type(t : T?.class) forall T + T + end + + def self.flatten_type(t : T.class) forall T + T + end + + record ArrayHeader, + dimensions : Int32, + oid : Int32, + dim_info : Array({dim: Int32, lbound: Int32}) + + def self.decode_array_header(io) + dimensions = read_i32(io) + has_null = read_i32(io) == 1 # unused + oid = read_i32(io) # unused but in header + dim_info = Array({dim: Int32, lbound: Int32}).new(dimensions) do |i| + { + dim: read_i32(io), + lbound: read_i32(io), + } + end + + ArrayHeader.new(dimensions, oid, dim_info) + end + + def self.read_i32(io) + io.read_bytes(Int32, IO::ByteFormat::NetworkEndian) + end + end + + macro array_type(t, oid) + alias {{t}}Array = {{t}}? | Array({{t}}Array) + + module Decoders + register_decoder ArrayDecoder({{t}}Array, {{t}}Decoder).new({{oid}}) + end + end + + array_type Bool, 1000 + array_type Char, 1002 + array_type Int16, 1005 + array_type Int32, 1007 + array_type Int64, 1016 + array_type Float32, 1021 + array_type Float64, 1022 + array_type String, 1009 + array_type Numeric, 1231 + array_type Time, [1115, 1182] +end diff --git a/lib/pg/src/pg/driver.cr b/lib/pg/src/pg/driver.cr new file mode 100644 index 00000000..423bbafc --- /dev/null +++ b/lib/pg/src/pg/driver.cr @@ -0,0 +1,8 @@ +class PG::Driver < ::DB::Driver + def build_connection(context : ::DB::ConnectionContext) : Connection + Connection.new(context) + end +end + +DB.register_driver "postgres", PG::Driver +DB.register_driver "postgresql", PG::Driver diff --git a/lib/pg/src/pg/error.cr b/lib/pg/src/pg/error.cr new file mode 100644 index 00000000..488ce3d6 --- /dev/null +++ b/lib/pg/src/pg/error.cr @@ -0,0 +1,8 @@ +module PG + # :nodoc: + class Error < ::Exception + end + + class RuntimeError < Error + end +end diff --git a/lib/pg/src/pg/escape_helper.cr b/lib/pg/src/pg/escape_helper.cr new file mode 100644 index 00000000..791df3eb --- /dev/null +++ b/lib/pg/src/pg/escape_helper.cr @@ -0,0 +1,92 @@ +module PG + module EscapeHelper + extend self + + # `#escape_identifier` escapes a string for use as an SQL identifier, such + # as a table, column, or function name. This is useful when a user-supplied + # identifier might contain special characters that would otherwise not be + # interpreted as part of the identifier by the SQL parser, or when the + # identifier might contain upper case characters whose case should be + # preserved. + def escape_identifier(str) + escape str, true + end + + # `#escape_literal` escapes a string for use within an SQL command. This is + # useful when inserting data values as literal constants in SQL commands. + # Certain characters (such as quotes and backslashes) must be escaped to + # prevent them from being interpreted specially by the SQL parser. + # PQescapeLiteral performs this operation. + # + # Note that it is not necessary nor correct to do escaping when a data + # value is passed as a separate parameter in `#exec` + def escape_literal(str) + escape str, false + end + + # `#escape_literal` escapes binary data suitable for use with the BYTEA type. + def escape_literal(slice : Slice(UInt8)) + ssize = slice.size * 2 + 4 + String.new(ssize) do |buffer| + buffer[0] = '\''.ord.to_u8 + buffer[1] = '\\'.ord.to_u8 + buffer[2] = 'x'.ord.to_u8 + slice.hexstring(buffer + 3) + buffer[ssize - 1] = '\''.ord.to_u8 + {ssize, ssize} + end + end + + # reimplimentation of PQescapeInternal + # todo this should take into account server encoding if not utf8 + private def escape(str : String, as_ident : Bool) + num_quotes = 0 + num_backslashes = 0 + quote_char = as_ident ? '"' : '\'' + + # scan the string for characters that must be escaped + str.each_char do |char| + case char + when '\\' + num_backslashes += 1 + when quote_char + num_quotes += 1 + end + end + + literal_with_backslashes = (!as_ident && num_backslashes > 0) + + result_size = str.size + num_quotes + 2 + if literal_with_backslashes + result_size += num_backslashes + 2 + end + + String.build(result_size) do |build| + if literal_with_backslashes + build << ' ' << 'E' + end + build << quote_char + + if num_backslashes == num_quotes == 0 + str.each_char { |c| build << c } + else + str.each_char do |c| + case c + when quote_char + build << quote_char + when '\\' + build << '\\' unless as_ident + end + build << c + end + end + + build << quote_char + end + end + end + + class Connection + include EscapeHelper + end +end diff --git a/lib/pg/src/pg/geo.cr b/lib/pg/src/pg/geo.cr new file mode 100644 index 00000000..3f551bed --- /dev/null +++ b/lib/pg/src/pg/geo.cr @@ -0,0 +1,21 @@ +module PG::Geo + record Point, x : Float64, y : Float64 + record Line, a : Float64, b : Float64, c : Float64 + record Circle, x : Float64, y : Float64, radius : Float64 + record LineSegment, x1 : Float64, y1 : Float64, x2 : Float64, y2 : Float64 + record Box, x1 : Float64, y1 : Float64, x2 : Float64, y2 : Float64 + + struct Path + getter points + getter? closed + + def initialize(@points : Array(Point), @closed : Bool) + end + + def open? + !closed? + end + end + + record Polygon, points : Array(Point) +end diff --git a/lib/pg/src/pg/numeric.cr b/lib/pg/src/pg/numeric.cr new file mode 100644 index 00000000..19e3bcd7 --- /dev/null +++ b/lib/pg/src/pg/numeric.cr @@ -0,0 +1,130 @@ +module PG + # The Postgres numeric type has arbitrary precision, and can be NaN, "not a + # number". + # + # The default version of `Numeric` in this driver only has `#to_f` which + # provides approximate conversion. To get true arbitrary precision, there is + # an optional extension `pg_ext/big_rational`, however LibGMP must be + # installed. + struct Numeric + # :nodoc: + enum Sign + Pos = 0x0000 + Neg = 0x4000 + Nan = -0x4000 + end + + # size of digits array + getter ndigits : Int16 + + # location of decimal point in digits array + # can be negative for small numbers such as 0.0000001 + getter weight : Int16 + + # positive, negative, or nan + getter sign : Sign + + # number of decimal point digits shown + # 1.10 is and 1.100 would only differ here + getter dscale : Int16 + + # array of numbers from 0-10,000 representing the numeric + # (not an array of individual digits!) + getter digits : Array(Int16) + + def initialize(@ndigits : Int16, @weight : Int16, sign, @dscale : Int16, @digits : Array(Int16)) + @sign = Sign.from_value(sign) + end + + # Returns `true` if the numeric is not a number. + def nan? + sign == Sign::Nan + end + + # Returns `true` if the numeric is negative. + def neg? + sign == Sign::Neg + end + + # The approximate representation of the numeric as a 64-bit float. + # + # Very small and very large values may be inaccurate and precision will be + # lost. + # NaN returns `0.0`. + def to_f : Float64 + to_f64 + end + + # ditto + def to_f64 : Float64 + num = digits.reduce(0_u64) { |a, i| a*10_000_u64 + i.to_u64 } + den = 10_000_f64**(ndigits - 1 - weight) + quot = num.to_f64 / den.to_f64 + neg? ? -quot : quot + end + + def inspect(io : IO) + to_s(io) + end + + def to_s(io : IO) + if ndigits == 0 + if nan? + io << "NaN" + else + io << '0' + if dscale > 0 + io << '.' + dscale.times { io << '0' } + end + end + + return + end + + io << '-' if neg? + + pos = 0 + + if weight >= 0 + io << digits[0].to_s + pos += 1 + (1..weight).each do |idx| + pos += 1 + str = digits[idx]?.to_s + (4 - str.size).times { io << '0' } + io << str + end + end + + return if dscale <= 0 + + io << '0' if weight < 0 + io << '.' + + count = 0 + (-1 - weight).times do + io << "0000" + count += 4 + end + + (pos...ndigits).each do |idx| + str = digits[idx].to_s + + (4 - str.size).times do + io << '0' + count += 1 + end + + if idx == ndigits - 1 + remain = (dscale + str.size) % 4 + str = str[0...remain] unless remain == 0 + end + io << str + count += str.size + end + + (dscale - count).times { io << '0' } + end + end +end diff --git a/lib/pg/src/pg/result_set.cr b/lib/pg/src/pg/result_set.cr new file mode 100644 index 00000000..81adf1e8 --- /dev/null +++ b/lib/pg/src/pg/result_set.cr @@ -0,0 +1,168 @@ +class PG::ResultSet < ::DB::ResultSet + getter rows_affected + + def initialize(statement, @fields : Array(PQ::Field)?) + super(statement) + @column_index = -1 # The current column + @end = false # Did we read all the rows? + @rows_affected = 0_i64 + end + + protected def conn + statement.as(Statement).conn + end + + def move_next : Bool + return false if @end + + fields = @fields + + # `move_next` might be called before consuming all rows, + # in that case we need to skip columns + if fields && @column_index > -1 && @column_index < fields.size + while @column_index < fields.size + skip + end + end + + unless fields + @end = true + frame = conn.expect_frame PQ::Frame::CommandComplete | PQ::Frame::EmptyQueryResponse + if frame.is_a?(PQ::Frame::CommandComplete) + @rows_affected = frame.rows_affected + end + + conn.expect_frame PQ::Frame::ReadyForQuery + return false + end + + if conn.read_next_row_start + # We ignore these (redundant information) + conn.read_i32 # size + conn.read_i16 # ncols + @column_index = 0 + true + else + conn.expect_frame PQ::Frame::ReadyForQuery + @end = true + false + end + rescue IO::Error + raise DB::ConnectionLost.new(statement.connection) + rescue ex + @end = true + raise ex + end + + def column_count : Int32 + @fields.try(&.size) || 0 + end + + def column_name(index : Int32) : String + field(index).name + end + + def column_type(index : Int32) + decoder(index).type + end + + def read + col_bytesize = conn.read_i32 + if col_bytesize == -1 + @column_index += 1 + return nil + end + + safe_read(col_bytesize) do |io| + decoder.decode(io, col_bytesize, oid) + end + rescue IO::Error + raise DB::ConnectionLost.new(statement.connection) + end + + def read(t : Array(T).class) : Array(T) forall T + read_array(Array(T)) do + raise PG::RuntimeError.new("unexpected NULL, expecting to read #{t}") + end + end + + def read(t : Array(T)?.class) : Array(T)? forall T + read_array(Array(T)) do + return nil + end + end + + private def read_array(t : T.class) : T forall T + col_bytesize = conn.read_i32 + if col_bytesize == -1 + @column_index += 1 + yield + end + + safe_read(col_bytesize) do |io| + Decoders.decode_array(io, col_bytesize, T) + end + rescue IO::Error + raise DB::ConnectionLost.new(statement.connection) + end + + private def safe_read(col_bytesize) + sized_io = IO::Sized.new(conn.soc, col_bytesize) + + begin + yield sized_io + ensure + # An exception might happen while decoding the value: + # 1. Make sure to skip the column bytes + # 2. Make sure to increment the column index + conn.soc.skip(sized_io.read_remaining) if sized_io.read_remaining > 0 + @column_index += 1 + end + end + + private def field(index = @column_index) + @fields.not_nil![index] + end + + private def decoder(index = @column_index) + Decoders.from_oid(oid(index)) + end + + private def oid(index = @column_index) + field(index).type_oid + end + + private def skip + col_size = conn.read_i32 + conn.skip_bytes(col_size) if col_size != -1 + @column_index += 1 + rescue IO::Error + raise DB::ConnectionLost.new(statement.connection) + end + + protected def do_close + super + + # Nothing to do if all the rows were consumed + return if @end + + # Check if we didn't advance to the first row + if @column_index == -1 + return unless move_next + end + + fields = @fields + + loop do + # Skip remaining columns + while fields && @column_index < fields.size + skip + end + + break unless move_next + end + rescue DB::ConnectionLost + # if the connection is lost there is nothing to be + # done since the result set is no longer needed + end +end diff --git a/lib/pg/src/pg/statement.cr b/lib/pg/src/pg/statement.cr new file mode 100644 index 00000000..f873bee9 --- /dev/null +++ b/lib/pg/src/pg/statement.cr @@ -0,0 +1,44 @@ +class PG::Statement < ::DB::Statement + def initialize(connection, @sql : String) + super(connection) + end + + protected def conn + connection.as(Connection).connection + end + + protected def perform_query(args : Enumerable) : ResultSet + params = args.map { |arg| PQ::Param.encode(arg) } + conn = self.conn + conn.send_parse_message(@sql) + conn.send_bind_message params + conn.send_describe_portal_message + conn.send_execute_message + conn.send_sync_message + conn.expect_frame PQ::Frame::ParseComplete + conn.expect_frame PQ::Frame::BindComplete + frame = conn.read + case frame + when PQ::Frame::RowDescription + fields = frame.fields + when PQ::Frame::NoData + fields = nil + else + raise "expected RowDescription or NoData, got #{frame}" + end + ResultSet.new(self, fields) + rescue IO::Error + raise DB::ConnectionLost.new(connection) + end + + protected def perform_exec(args : Enumerable) : ::DB::ExecResult + result = perform_query(args) + result.each { } + ::DB::ExecResult.new( + rows_affected: result.rows_affected, + last_insert_id: 0_i64 # postgres doesn't support this + ) + rescue IO::Error + raise DB::ConnectionLost.new(connection) + end +end diff --git a/lib/pg/src/pg/version.cr b/lib/pg/src/pg/version.cr new file mode 100644 index 00000000..e5593468 --- /dev/null +++ b/lib/pg/src/pg/version.cr @@ -0,0 +1,3 @@ +module PG + VERSION = "0.18.1" +end diff --git a/lib/pg/src/pg_ext/big_rational.cr b/lib/pg/src/pg_ext/big_rational.cr new file mode 100644 index 00000000..30e4f63d --- /dev/null +++ b/lib/pg/src/pg_ext/big_rational.cr @@ -0,0 +1,27 @@ +require "big" + +module PG + struct Numeric + # Returns a BigRational representation of the numeric. This retains all + # precision, but requires LibGMP installed. + def to_big_r + return BigRational.new(0, 1) if nan? || ndigits == 0 + + ten_k = BigInt.new(10_000) + num = digits.reduce(BigInt.new(0)) { |a, i| a*ten_k + BigInt.new(i) } + den = ten_k**(ndigits - 1 - weight) + quot = BigRational.new(num, den) + neg? ? -quot : quot + end + end + + class ResultSet + def read(t : BigRational.class) + read(PG::Numeric).to_big_r + end + + def read(t : BigRational?.class) + read(PG::Numeric?).try &.to_big_r + end + end +end diff --git a/lib/pg/src/pq/connection.cr b/lib/pg/src/pq/connection.cr new file mode 100644 index 00000000..0831abb6 --- /dev/null +++ b/lib/pg/src/pq/connection.cr @@ -0,0 +1,466 @@ +require "uri" +require "digest/md5" +require "socket" +require "socket/tcp_socket" +require "socket/unix_socket" +require "openssl" +require "openssl/hmac" + +module PQ + record Notification, pid : Int32, channel : String, payload : String + + # :nodoc: + class Connection + getter soc : UNIXSocket | TCPSocket | OpenSSL::SSL::Socket::Client + getter server_parameters : Hash(String, String) + property notice_handler : Notice -> + property notification_handler : Notification -> + + def initialize(@conninfo : ConnInfo) + @mutex = Mutex.new + @server_parameters = Hash(String, String).new + @established = false + @notice_handler = Proc(Notice, Void).new { } + @notification_handler = Proc(Notification, Void).new { } + + begin + if @conninfo.host[0] == '/' + soc = UNIXSocket.new(@conninfo.host) + else + soc = TCPSocket.new(@conninfo.host, @conninfo.port) + end + soc.sync = false + rescue e + raise ConnectionError.new("Cannot establish connection", cause: e) + end + + @soc = soc + negotiate_ssl if @soc.is_a?(TCPSocket) + end + + private def negotiate_ssl + write_i32 8 + write_i32 80877103 + @soc.flush + serv_ssl = case c = @soc.read_char + when 'S' then true + when 'N' then false + else + raise ConnectionError.new( + "Unexpected SSL response from server: #{c.inspect}") + end + + if serv_ssl + ctx = OpenSSL::SSL::Context::Client.new + ctx.verify_mode = OpenSSL::SSL::VerifyMode::NONE # currently emulating sslmode 'require' not verify_ca or verify_full + if sslcert = @conninfo.sslcert + ctx.certificate_chain = sslcert + end + if sslkey = @conninfo.sslkey + ctx.private_key = sslkey + end + if sslrootcert = @conninfo.sslrootcert + ctx.ca_certificates = sslrootcert + end + @soc = OpenSSL::SSL::Socket::Client.new(@soc, context: ctx, sync_close: true) + end + + if @conninfo.sslmode == :require && !@soc.is_a?(OpenSSL::SSL::Socket::Client) + close + raise ConnectionError.new("sslmode=require and server did not establish SSL") + end + end + + def close + synchronize do + return if @soc.closed? + send_terminate_message + @soc.close + end + end + + def synchronize + @mutex.synchronize { yield } + end + + private def write_i32(i : Int32) + soc.write_bytes i, IO::ByteFormat::NetworkEndian + end + + private def write_i32(i) + write_i32 i.to_i32 + end + + private def write_i16(i : Int16) + soc.write_bytes i, IO::ByteFormat::NetworkEndian + end + + private def write_i16(i) + write_i16 i.to_i16 + end + + private def write_null + soc.write_byte 0_u8 + end + + private def write_byte(byte) + soc.write_byte byte + end + + private def write_chr(chr : Char) + soc.write_byte chr.ord.to_u8 + end + + def read_i32 + soc.read_bytes(Int32, IO::ByteFormat::NetworkEndian) + end + + def read_i16 + soc.read_bytes(Int16, IO::ByteFormat::NetworkEndian) + end + + def read_bytes(count) + data = Slice(UInt8).new(count) + soc.read_fully(data) + data + end + + def skip_bytes(count) + soc.skip(count) + end + + def startup(args) + len = args.reduce(0) { |acc, arg| acc + arg.size + 1 } + write_i32 len + 8 + 1 + write_i32 0x30000 + args.each { |arg| soc << arg << '\0' } + write_null + soc.flush + end + + def read_data_row + size = read_i32 + ncols = read_i16 + row = Array(Slice(UInt8)?).new(ncols.to_i32) do + col_size = read_i32 + if col_size == -1 + nil + else + read_bytes(col_size) + end + end + + yield row + end + + def read + read(soc.read_char) + end + + def read(frame_type) + frame = read_one_frame(frame_type) + handle_async_frames(frame) ? read : frame + end + + def read_async_frame_loop + loop do + break if @soc.closed? + begin + handle_async_frames(read_one_frame(soc.read_char)) + rescue e : Errno + e.errno == Errno::EBADF && @soc.closed? ? break : raise e + end + end + end + + private def read_one_frame(frame_type) + size = read_i32 + slice = read_bytes(size - 4) + Frame.new(frame_type.not_nil!, slice) # .tap { |f| p f } + end + + private def handle_async_frames(frame) + if frame.is_a?(Frame::ErrorResponse) + handle_error frame + true + elsif frame.is_a?(Frame::NotificationResponse) + handle_notification frame + true + elsif frame.is_a?(Frame::NoticeResponse) + handle_notice frame + true + elsif frame.is_a?(Frame::ParameterStatus) + handle_parameter frame + true + else + false + end + end + + private def handle_error(error_frame : Frame::ErrorResponse) + expect_frame Frame::ReadyForQuery if @established + notice_handler.call(error_frame.as_notice) + raise PQError.new(error_frame.fields) + end + + private def handle_notice(frame : Frame::NoticeResponse) + notice_handler.call(frame.as_notice) + end + + private def handle_notification(frame : Frame::NotificationResponse) + notification_handler.call(frame.as_notification) + end + + private def handle_parameter(frame : Frame::ParameterStatus) + @server_parameters[frame.key] = frame.value + case frame.key + when "client_encoding" + if frame.value.upcase != "UTF8" + raise ConnectionError.new( + "Only UTF8 is supported for client_encoding, got: #{frame.value.inspect}") + end + when "integer_datetimes" + if frame.value != "on" + raise ConnectionError.new( + "Only on is supported for integer_datetimes, got: #{frame.value.inspect}") + end + end + end + + def connect + startup_args = [ + "user", @conninfo.user, + "database", @conninfo.database, + "application_name", "crystal", + "client_encoding", "utf8", + ] + + startup startup_args + + auth_frame = expect_frame Frame::Authentication + handle_auth auth_frame + + loop do + case frame = read + when Frame::BackendKeyData + # do nothing + when Frame::ReadyForQuery + break + else + raise "Expected BackendKeyData or ReadyForQuery but was #{frame}" + end + end + + @established = true + end + + private def handle_auth(auth_frame) + case auth_frame.type + when Frame::Authentication::Type::OK + # no op + when Frame::Authentication::Type::CleartextPassword + raise "Cleartext auth is not supported" + when Frame::Authentication::Type::SASL + handle_auth_sasl auth_frame.body + when Frame::Authentication::Type::MD5Password + handle_auth_md5 auth_frame.body + else + raise ConnectionError.new( + "unsupported authentication method: #{auth_frame.type}" + ) + end + end + + struct SamlContext + property client_first_msg : String + property client_first_msg_size : Int32 + + def initialize(@password : String) + @client_nonce = Random::Secure.urlsafe_base64(18) + @client_first_msg = "n,,n=,r=#{@client_nonce}" + @client_first_msg_size = @client_first_msg.bytesize + end + + def generate_client_final_message(body) + server_first_msg = String.new(body) + params = server_first_msg.split(',') + r = params.find { |p| p[0] == 'r' }.not_nil![2..-1] + s = params.find { |p| p[0] == 's' }.not_nil![2..-1] + i = params.find { |p| p[0] == 'i' }.not_nil![2..-1].to_i + raise ConnectionError.new("SASL: scram server nonce does not start with client nonce") unless r.starts_with?(@client_nonce) + + client_final_msg_without_proof = "c=biws,r=#{r}" + salted_pass = OpenSSL::PKCS5.pbkdf2_hmac(@password, Base64.decode(s), i, algorithm: OpenSSL::Algorithm::SHA256, key_size: 32) + server_key = OpenSSL::HMAC.digest(:sha256, salted_pass, "Server Key") + client_key = OpenSSL::HMAC.digest(:sha256, salted_pass, "Client Key") + auth_msg = "n=,r=#{@client_nonce},#{server_first_msg},#{client_final_msg_without_proof}" + client_sig = OpenSSL::HMAC.digest(:sha256, sha256(client_key), auth_msg) + @server_sig = OpenSSL::HMAC.digest(:sha256, server_key, auth_msg) + proof = Base64.strict_encode Slice.new(32) { |i| client_key[i].as(UInt8) ^ client_sig[i].as(UInt8) } + "#{client_final_msg_without_proof},p=#{proof}" + end + + def verify_server_signature(server_message) + server_sig = Base64.strict_encode @server_sig.not_nil! + raise ConnectionError.new("server signature does not match") unless server_message[2..-1] == server_sig.to_slice + end + + private def sha256(key) + OpenSSL::Digest.new("SHA256").update(key).digest + end + end + + private def handle_auth_sasl(mechanism_list) + # it is possible in the future for postgres to send something other than + # SCRAM-SHA-265, but for now ignore the mechanism_list + + ctx = SamlContext.new(@conninfo.password || "") + + # send client-first-message + write_chr 'p' # SASLInitialResponse + write_i32 4 + 13 + 1 + 4 + ctx.client_first_msg_size + soc << "SCRAM-SHA-256" + write_null + write_i32 ctx.client_first_msg_size + soc << ctx.client_first_msg + soc.flush + + # receive server-first-message + continue = expect_frame Frame::Authentication + final_msg = ctx.generate_client_final_message(continue.body) + + # send client-final-message + write_chr 'p' + write_i32 4 + final_msg.bytesize + soc << final_msg + soc.flush + + # receive server-final-message + final = expect_frame Frame::Authentication + ctx.verify_server_signature(final.body) + # receive OK + expect_frame Frame::Authentication + end + + private def handle_auth_md5(salt) + inner = Digest::MD5.hexdigest("#{@conninfo.password}#{@conninfo.user}") + + pass = Digest::MD5.hexdigest do |ctx| + ctx.update(inner.to_unsafe, inner.bytesize.to_u32) + ctx.update(salt.to_unsafe, salt.bytesize.to_u32) + end + + send_password_message "md5#{pass}" + expect_frame Frame::Authentication + end + + def read_next_row_start + type = soc.read_char + + while type == 'N' + # NoticeResponse + frame = read_one_frame('N') + handle_async_frames(frame) + type = soc.read_char + end + + if type == 'D' + true + else + expect_frame Frame::CommandComplete, type + false + end + end + + def read_all_data_rows + type = soc.read_char + loop do + break unless type == 'D' + read_data_row { |row| yield row } + type = soc.read_char + end + expect_frame Frame::CommandComplete, type + end + + def expect_frame(frame_class, type = nil) + f = type ? read(type) : read + raise "Expected #{frame_class} but got #{f}" unless frame_class === f + frame_class.cast(f) + end + + def send_password_message(password) + write_chr 'p' + if password + write_i32 password.size + 4 + 1 + soc << password + else + write_i32 4 + 1 + end + write_null + soc.flush + end + + def send_query_message(query) + write_chr 'Q' + write_i32 query.bytesize + 4 + 1 + soc << query + write_null + soc.flush + end + + def send_parse_message(query) + write_chr 'P' + write_i32 query.bytesize + 4 + 1 + 2 + 1 + write_null # prepared statment name + soc << query + write_i16 0 # don't give any param types + write_null + end + + def send_bind_message(params) + nparams = params.size + total_size = params.reduce(0) do |acc, p| + acc + 4 + (p.size == -1 ? 0 : p.size) + end + + write_chr 'B' + write_i32 4 + 1 + 1 + 2 + (2*nparams) + 2 + total_size + 2 + 2 + write_null # unnamed destination portal + write_null # unnamed prepared statment + write_i16 nparams # number of params format codes to follow + params.each { |p| write_i16 p.format } + write_i16 nparams # number of params to follow + params.each do |p| + write_i32 p.size + p.slice.each { |byte| write_byte byte } + end + write_i16 1 # number of following return types (1 means apply next for all) + write_i16 1 # all results as binary + end + + def send_describe_portal_message + write_chr 'D' + write_i32 4 + 1 + 1 + write_chr 'P' + write_null + end + + def send_execute_message + write_chr 'E' + write_i32 4 + 1 + 4 + write_null # unnamed portal + write_i32 0 # unlimited maximum rows + end + + def send_sync_message + write_chr 'S' + write_i32 4 + soc.flush + end + + def send_terminate_message + write_chr 'X' + write_i32 4 + end + end +end diff --git a/lib/pg/src/pq/conninfo.cr b/lib/pg/src/pq/conninfo.cr new file mode 100644 index 00000000..e4abf99e --- /dev/null +++ b/lib/pg/src/pq/conninfo.cr @@ -0,0 +1,141 @@ +require "uri" +require "http" + +module PQ + struct ConnInfo + SOCKET_SEARCH = %w(/run/postgresql/.s.PGSQL.5432 /tmp/.s.PGSQL.5432 /var/run/postgresql/.s.PGSQL.5432) + + # The host. If starts with a / it is assumed to be a local Unix socket. + getter host : String + + # The port, defaults to 5432. It is ignored for local Unix sockets. + getter port : Int32 + + # The database name. + getter database : String + + # The user. + getter user : String + + # The password. Optional. + getter password : String? + + # The sslmode. Optional (:prefer is default). + getter sslmode : Symbol + + # The sslcert. Optional. + getter sslcert : String? + + # The sslkey. Optional. + getter sslkey : String? + + # The sslrootcert. Optional. + getter sslrootcert : String? + + # Create a new ConnInfo from all parts + def initialize(host : String? = nil, database : String? = nil, user : String? = nil, @password : String? = nil, port : Int | String? = 5432, sslmode : String | Symbol? = nil) + @host = default_host host + db = default_database database + @database = db.starts_with?('/') ? db[1..-1] : db + @user = default_user user + @port = (port || 5432).to_i + @sslmode = default_sslmode sslmode + end + + # Initialize with either "postgres://" urls or postgres "key=value" pairs + def self.from_conninfo_string(conninfo : String) + if conninfo.starts_with?("postgres://") || conninfo.starts_with?("postgresql://") + new(URI.parse(conninfo)) + else + return new if conninfo == "" + + args = Hash(String, String).new + conninfo.split(' ').each do |pair| + begin + k, v = pair.split('=') + args[k] = v + rescue IndexError + raise ArgumentError.new("invalid paramater: #{pair}") + end + end + new(args) + end + end + + # Initializes with a `URI` + def initialize(uri : URI) + initialize(uri.host, uri.path, uri.user, uri.password, uri.port, :prefer) + if q = uri.query + HTTP::Params.parse(q) do |key, value| + handle_sslparam(key, value) + end + end + end + + # Initialize with a `Hash` + # + # Valid keys match Postgres "conninfo" keys and are `"host"`, `"dbname"`, + # `"user"`, `"password"`, `"port"`, `"sslmode"`, `"sslcert"`, `"sslkey"` and `"sslrootcert"` + def initialize(params : Hash) + initialize(params["host"]?, params["dbname"]?, params["user"]?, + params["password"]?, params["port"]?, params["sslmode"]?) + params.each do |key, value| + handle_sslparam(key, value) + end + end + + private def handle_sslparam(key : String, value : String) + case key + when "sslmode" + @sslmode = default_sslmode value + when "sslcert" + @sslcert = value + when "sslkey" + @sslkey = value + when "sslrootcert" + @sslrootcert = value + end + end + + private def default_host(h) + return h if h && !h.blank? + + SOCKET_SEARCH.each do |s| + return s if File.exists?(s) + end + + "localhost" + end + + private def default_database(db) + if db && db != "/" + db + else + `whoami`.chomp + end + end + + private def default_user(u) + u || `whoami`.chomp + end + + private def default_sslmode(mode) + case mode + when nil, :prefer, "prefer" + :prefer + when :disable, "disable" + :disable + when :allow, "allow" + :allow + when :require, "require" + :require + when :"verify-ca", "verify-ca" + :"verify-ca" + when :"verify-full", "verify-full" + :"verify-full" + else + raise ArgumentError.new("sslmode #{mode} not supported") + end + end + end +end diff --git a/lib/pg/src/pq/error.cr b/lib/pg/src/pq/error.cr new file mode 100644 index 00000000..11eb24e8 --- /dev/null +++ b/lib/pg/src/pq/error.cr @@ -0,0 +1,17 @@ +module PQ + class ConnectionError < Exception + end + + class PQError < Exception + getter fields : Array(Frame::ErrorResponse::Field) + + def initialize(@fields) + super(field_message :message) + end + + def field_message(name) + field = fields.find { |f| f.name == name } + field.message if field + end + end +end diff --git a/lib/pg/src/pq/field.cr b/lib/pg/src/pq/field.cr new file mode 100644 index 00000000..c157e45a --- /dev/null +++ b/lib/pg/src/pq/field.cr @@ -0,0 +1,10 @@ +module PQ + class Field + getter name, type_oid + + def initialize(@name : String, @col_oid : Int32, @table_oid : Int16, + @type_oid : Int32, @type_size : Int16, @type_modifier : Int32, + @format : Int16) + end + end +end diff --git a/lib/pg/src/pq/frame.cr b/lib/pg/src/pq/frame.cr new file mode 100644 index 00000000..ae244ed8 --- /dev/null +++ b/lib/pg/src/pq/frame.cr @@ -0,0 +1,244 @@ +module PQ + # :nodoc: + abstract struct Frame + getter bytes + + def self.new(type : Char, bytes : Slice(UInt8)) + k = case type + when 'C' then CommandComplete + when 'Z' then ReadyForQuery + when '1' then ParseComplete + when '2' then BindComplete + when 'T' then RowDescription + when 'A' then NotificationResponse + when 'E' then ErrorResponse + when 'N' then NoticeResponse + when 'n' then NoData + when 'I' then EmptyQueryResponse + # when 'D' then DataRow + when 'S' then ParameterStatus + when 'K' then BackendKeyData + when 'R' then Authentication + end + k ? k.new(bytes) : Unknown.new(type, bytes) + end + + def initialize(bytes) + end + + private def find_next_string(pos, bytes) + start = pos + (bytes + start).each do |c| + break if c == 0 + pos += 1 + end + return pos + 1, String.new(bytes[start, pos - start]) + end + + private def i32(pos, bytes) : {Int32, Int32} + return pos + 4, (bytes[pos + 3].to_i32 << 0) | + (bytes[pos + 2].to_i32 << 8) | + (bytes[pos + 1].to_i32 << 16) | + (bytes[pos + 0].to_i32 << 24) + end + + private def i16(pos, bytes) : {Int32, Int16} + return pos + 2, (bytes[pos + 1].to_i16 << 0) | + (bytes[pos + 0].to_i16 << 8) + end + + struct Unknown + getter type : Char + getter bytes : Slice(UInt8) + + def initialize(@type, @bytes) + end + end + + struct Authentication < Frame + enum Type : Int32 + OK = 0 + KerberosV5 = 2 + CleartextPassword = 3 + MD5Password = 5 + SCMCredential = 6 + GSS = 7 + GSSContinue = 8 + SASL = 10 + SASLContinue = 11 + SASLFinal = 12 + end + + getter type : Type + getter body : Slice(UInt8) + + def initialize(bytes) + pos, t = i32(0, bytes) + @type = Type.from_value t + @body = bytes + pos + end + end + + struct ParameterStatus < Frame + getter key : String + getter value : String + + def initialize(bytes) + pos, @key = find_next_string(0, bytes) + @value = String.new(bytes[pos, bytes.size - pos - 1]) + end + end + + struct BackendKeyData < Frame + getter pid : Int32 + getter secret : Int32 + + def initialize(bytes) + pos = 0 + pos, @pid = i32(pos, bytes) + pos, @secret = i32(pos, bytes) + end + end + + struct ReadyForQuery < Frame + enum Status : UInt8 + Idle = 0x49 # I + Transaction = 0x54 # T + Error = 0x45 # E + end + + getter transaction_status : Status + + def initialize(bytes) + @transaction_status = Status.from_value bytes[0] + end + end + + abstract struct ErrorNoticeFrame < Frame + record Field, name : Symbol, message : String, code : UInt8 do + def inspect(io) + io << name << ": " << message + end + end + getter fields : Array(Field) + + def initialize(bytes) + @fields = Array(Field).new + pos = 0 + loop do + code = bytes[pos] + break if code == 0 + pos += 1 + pos, message = find_next_string(pos, bytes) + @fields << Field.new(name_from_code(code), message, code) + end + end + + def as_notice : Notice + Notice.new(fields) + end + + private def name_from_code(code) + case code + when 'S' then :severity + when 'C' then :code + when 'M' then :message + when 'D' then :detail + when 'H' then :hint + when 'P' then :position + when 'p' then :internal_position + when 'q' then :internal_query + when 'W' then :where + when 's' then :schema_name + when 't' then :table_name + when 'c' then :column_name + when 'd' then :datatype_name + when 'n' then :constraint_name + when 'F' then :file + when 'L' then :line + when 'R' then :routine + else :unknown + end + end + end + + struct NotificationResponse < Frame + getter pid : Int32 + getter channel : String + getter payload : String + + def initialize(bytes) + pos = 0 + pos, @pid = i32 pos, bytes + pos, @channel = find_next_string(pos, bytes) + pos, @payload = find_next_string(pos, bytes) + end + + def as_notification : PQ::Notification + PQ::Notification.new(pid, channel, payload) + end + end + + struct ErrorResponse < ErrorNoticeFrame + end + + struct NoticeResponse < ErrorNoticeFrame + end + + struct RowDescription < Frame + getter nfields : Int16 + getter fields : Array(Field) + + def initialize(bytes) + pos = 0 + pos, @nfields = i16 pos, bytes + @fields = Array(Field).new(@nfields.to_i32) do + pos, name = find_next_string(pos, bytes) + pos, col_oid = i32 pos, bytes + pos, table_oid = i16 pos, bytes + pos, type_oid = i32 pos, bytes + pos, type_size = i16 pos, bytes + pos, type_modifier = i32 pos, bytes + pos, format = i16 pos, bytes + Field.new(name, col_oid, table_oid, type_oid, type_size, type_modifier, format) + end + end + end + + # struct DataRow < Frame + # def initialize(bytes, &block : Int16, Slice(UInt8) ->) + # pos, nrows = i16 0, bytes + # nrows.times do |i| + # pos, size = i32 pos, bytes + # block.call(i, bytes[pos, size]) + # pos += size + # end + # end + # end + + struct NoData < Frame + end + + struct CommandComplete < Frame + def initialize(@bytes : Bytes) + end + + def command + String.new(@bytes[0, @bytes.size - 1]) + end + + def rows_affected + command.split.last.to_i64? || 0_i64 + end + end + + struct ParseComplete < Frame + end + + struct BindComplete < Frame + end + + struct EmptyQueryResponse < Frame + end + end +end diff --git a/lib/pg/src/pq/notice.cr b/lib/pg/src/pq/notice.cr new file mode 100644 index 00000000..31cc6313 --- /dev/null +++ b/lib/pg/src/pq/notice.cr @@ -0,0 +1,37 @@ +module PQ + # http://www.postgresql.org/docs/current/static/protocol-error-fields.html + struct Notice + getter fields : Array(PQ::Frame::ErrorNoticeFrame::Field) + getter severity : String + getter code : String + getter message : String + + def initialize(@fields) + severity = "" + code = "" + message = "" + + fields.each do |f| + case f.name + when :severity + severity = f.message + when :code + code = f.message + when :message + message = f.message + end + end + + @severity = severity + @code = code + @message = message + end + + def to_s(io : IO) + io << severity + io << ": " + io << message + io << '\n' + end + end +end diff --git a/lib/pg/src/pq/param.cr b/lib/pg/src/pq/param.cr new file mode 100644 index 00000000..61bca48b --- /dev/null +++ b/lib/pg/src/pq/param.cr @@ -0,0 +1,148 @@ +require "../pg/geo" + +module PQ + # :nodoc: + record Param, slice : Slice(UInt8), size : Int32, format : Int16 do + delegate to_unsafe, to: slice + + # Internal wrapper to represent an encoded parameter + + def self.encode(val : Nil) + binary Pointer(UInt8).null.to_slice(0), -1 + end + + def self.encode(val : Slice) + binary val, val.size + end + + def self.encode(val : Array) + text encode_array(val) + end + + def self.encode(val : Time) + text Time::Format::RFC_3339.format(val) + end + + def self.encode(val : PG::Geo::Point) + text "(#{val.x},#{val.y})" + end + + def self.encode(val : PG::Geo::Line) + text "{#{val.a},#{val.b},#{val.c}}" + end + + def self.encode(val : PG::Geo::Circle) + text "<(#{val.x},#{val.y}),#{val.radius}>" + end + + def self.encode(val : PG::Geo::LineSegment) + text "((#{val.x1},#{val.y1}),(#{val.x2},#{val.y2}))" + end + + def self.encode(val : PG::Geo::Box) + text "((#{val.x1},#{val.y1}),(#{val.x2},#{val.y2}))" + end + + def self.encode(val : PG::Geo::Path) + if val.closed? + encode_points "(", val.points, ")" + else + encode_points "[", val.points, "]" + end + end + + def self.encode(val : PG::Geo::Polygon) + encode_points "(", val.points, ")" + end + + private def self.encode_points(left, points, right) + string = String.build do |io| + io << left + points.each_with_index do |point, i| + io << "," if i > 0 + io << "(" << point.x << "," << point.y << ")" + end + io << right + end + + text string + end + + def self.encode(val) + text val.to_s + end + + def self.binary(slice, size) + new slice, size, 1_i16 + end + + def self.text(string : String) + text string.to_slice + end + + def self.text(slice : Bytes) + new slice, slice.size, 0_i16 + end + + def self.encode_array(array) + String.build(array.size + 2) do |io| + encode_array(io, array) + end + end + + def self.encode_array(io, value : Array) + io << "{" + value.join(",", io) do |item| + encode_array(io, item) + end + io << "}" + end + + def self.encode_array(io, value) + io << value + end + + def self.encode_array(io, value : Bool) + io << (value ? 't' : 'f') + end + + def self.encode_array(io, value : Bytes) + io << '"' + io << String.new(value).gsub(%("), %(\\")) + io << '"' + end + + def self.encode_array(io, value : String) + io << '"' + if value.ascii_only? + special_chars = {'"'.ord.to_u8, '\\'.ord.to_u8} + last_index = 0 + value.to_slice.each_with_index do |byte, index| + if special_chars.includes?(byte) + io.write value.unsafe_byte_slice(last_index, index - last_index) + last_index = index + io << '\\' + end + end + + io.write value.unsafe_byte_slice(last_index) + else + last_index = 0 + reader = Char::Reader.new(value) + while reader.has_next? + char = reader.current_char + if {'"', '\\'}.includes?(char) + io.write value.unsafe_byte_slice(last_index, reader.pos - last_index) + last_index = reader.pos + io << '\\' + end + reader.next_char + end + + io.write value.unsafe_byte_slice(last_index) + end + + io << '"' + end + end +end diff --git a/lib/pg/src/pq/query.cr b/lib/pg/src/pq/query.cr new file mode 100644 index 00000000..7e18a2eb --- /dev/null +++ b/lib/pg/src/pq/query.cr @@ -0,0 +1,57 @@ +module PQ + # :nodoc: + class ExtendedQuery + getter conn, query, params, fields + + def initialize(conn, query, params) + encoded_params = params.map { |v| Param.encode(v) } + initialize(conn, query, encoded_params) + end + + def initialize(@conn : Connection, @query : String, @params : Array(Param)) + conn.send_parse_message query + conn.send_bind_message params + conn.send_describe_portal_message + conn.send_execute_message + conn.send_sync_message + conn.expect_frame Frame::ParseComplete + conn.expect_frame Frame::BindComplete + + frame = conn.read + if frame.is_a?(Frame::RowDescription) + @fields = frame.fields + @has_data = true + elsif frame.is_a?(Frame::NoData) + @fields = [] of PQ::Field + conn.expect_frame Frame::CommandComplete | Frame::EmptyQueryResponse + conn.expect_frame Frame::ReadyForQuery + @has_data = false + else + raise "expected RowDescription or NoData, got #{frame}" + end + @got_data = false + end + + def get_data + raise "already read data" if @got_data + if @has_data + conn.read_all_data_rows { |row| yield row } + conn.expect_frame Frame::ReadyForQuery + end + @got_data = true + end + end + + # :nodoc: + class SimpleQuery + getter conn, query + + def initialize(@conn : Connection, @query : String) + conn.send_query_message(query) + + # read_all_data_rows { |row| yield row } + while !conn.read.is_a?(Frame::ReadyForQuery) + end + end + end +end diff --git a/lib/radix/.gitignore b/lib/radix/.gitignore new file mode 100644 index 00000000..591e49cd --- /dev/null +++ b/lib/radix/.gitignore @@ -0,0 +1,8 @@ +/doc/ +/lib/ +/.crystal/ +/.shards/ + +# Libraries don't need dependency lock +# Dependencies will be locked in application that uses them +/shard.lock diff --git a/lib/radix/.travis.yml b/lib/radix/.travis.yml new file mode 100644 index 00000000..33cd2975 --- /dev/null +++ b/lib/radix/.travis.yml @@ -0,0 +1,11 @@ +language: crystal +crystal: + - latest + - nightly +matrix: + allow_failures: + - crystal: nightly + +notifications: + email: + on_success: never diff --git a/lib/radix/CHANGELOG.md b/lib/radix/CHANGELOG.md new file mode 100644 index 00000000..81ddc75b --- /dev/null +++ b/lib/radix/CHANGELOG.md @@ -0,0 +1,102 @@ +# Change Log + +All notable changes to Radix project will be documented in this file. +This project aims to comply with [Semantic Versioning](http://semver.org/), +so please check *Changed* and *Removed* notes before upgrading. + +## [Unreleased] + +## [0.3.9] - 2019-01-02 +### Fixed +- Correct catch-all issue caused when paths differ [#26](https://github.com/luislavena/radix/pull/26) (@silasb) + +## [0.3.8] - 2017-03-12 +### Fixed +- Correct lookup issue caused by incorrect comparison of shared key [#21](https://github.com/luislavena/radix/issues/21) +- Improve support for non-ascii keys in a tree. + +## [0.3.7] - 2017-02-04 +### Fixed +- Correct prioritization of node's children using combination of kind and + priority, allowing partial shared keys to coexist and resolve lookup. + +## [0.3.6] - 2017-01-18 +### Fixed +- Correct lookup issue caused by similar priority between named paramter and + shared partial key [kemalcr/kemal#293](https://github.com/kemalcr/kemal/issues/293) + +## [0.3.5] - 2016-11-24 +### Fixed +- Correct lookup issue when dealing with catch all and shared partial key (@crisward) + +## [0.3.4] - 2016-11-12 +### Fixed +- Ensure catch all parameter can be used as optional globbing (@jwoertink) + +## [0.3.3] - 2016-11-12 [YANKED] +### Fixed +- Ensure catch all parameter can be used as optional globbing (@jwoertink) + +## [0.3.2] - 2016-11-05 +### Fixed +- Do not force adding paths with shared named parameter in an specific order (@jwoertink) +- Give proper name to `Radix::VERSION` spec when running in verbose mode. +- Ensure code samples in docs can be executed. + +## [0.3.1] - 2016-07-29 +### Added +- Introduce `Radix::VERSION` so library version can be used at runtime. + +## [0.3.0] - 2016-04-16 +### Fixed +- Improve forward compatibility with newer versions of the compiler by adding + missing types to solve type inference errors. + +### Changed +- `Radix::Tree` now requires the usage of a type which will be used as node's + payload. See [README](README.md) for details. + +## [0.2.1] - 2016-03-15 +### Fixed +- Correct `Result#key` incorrect inferred type. + +### Removed +- Attempt to use two named parameters at the same level will raise + `Radix::Tree::SharedKeyError` + +## [0.2.0] - 2016-03-15 [YANKED] +### Removed +- Attempt to use two named parameters at the same level will raise + `Radix::Tree::SharedKeyError` + +## [0.1.2] - 2016-03-10 +### Fixed +- No longer split named parameters that share same level (@alsm) + +### Changed +- Attempt to use two named parameters at same level will display a + deprecation warning. Future versions will raise `Radix::Tree::SharedKeyError` + +## [0.1.1] - 2016-02-29 +### Fixed +- Fix named parameter key names extraction. + +## [0.1.0] - 2016-01-24 +### Added +- Initial release based on code extracted from Beryl. + +[Unreleased]: https://github.com/luislavena/radix/compare/v0.3.9...HEAD +[0.3.9]: https://github.com/luislavena/radix/compare/v0.3.8...v0.3.9 +[0.3.8]: https://github.com/luislavena/radix/compare/v0.3.7...v0.3.8 +[0.3.7]: https://github.com/luislavena/radix/compare/v0.3.6...v0.3.7 +[0.3.6]: https://github.com/luislavena/radix/compare/v0.3.5...v0.3.6 +[0.3.5]: https://github.com/luislavena/radix/compare/v0.3.4...v0.3.5 +[0.3.4]: https://github.com/luislavena/radix/compare/v0.3.3...v0.3.4 +[0.3.3]: https://github.com/luislavena/radix/compare/v0.3.2...v0.3.3 +[0.3.2]: https://github.com/luislavena/radix/compare/v0.3.1...v0.3.2 +[0.3.1]: https://github.com/luislavena/radix/compare/v0.3.0...v0.3.1 +[0.3.0]: https://github.com/luislavena/radix/compare/v0.2.1...v0.3.0 +[0.2.1]: https://github.com/luislavena/radix/compare/v0.2.0...v0.2.1 +[0.2.0]: https://github.com/luislavena/radix/compare/v0.1.2...v0.2.0 +[0.1.2]: https://github.com/luislavena/radix/compare/v0.1.1...v0.1.2 +[0.1.1]: https://github.com/luislavena/radix/compare/v0.1.0...v0.1.1 diff --git a/lib/radix/LICENSE b/lib/radix/LICENSE new file mode 100644 index 00000000..15bbbc16 --- /dev/null +++ b/lib/radix/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 Luis Lavena + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/lib/radix/Makefile b/lib/radix/Makefile new file mode 100644 index 00000000..d5c06322 --- /dev/null +++ b/lib/radix/Makefile @@ -0,0 +1,18 @@ +CRYSTAL ?= crystal + +profile ?= ## Display profiling information after specs execution +verbose ?= ## Run specs in verbose mode + +SPEC_FLAGS := $(if $(profile),--profile )$(if $(verbose),--verbose ) + +.PHONY: default autospec spec + +default: spec + +# `autospec` task uses `watchexec` external dependency: +# https://github.com/mattgreen/watchexec +autospec: + watchexec --exts cr --watch spec --watch src --clear $(CRYSTAL) spec $(SPEC_FLAGS) + +spec: + $(CRYSTAL) spec $(SPEC_FLAGS) \ No newline at end of file diff --git a/lib/radix/README.md b/lib/radix/README.md new file mode 100644 index 00000000..0fe266de --- /dev/null +++ b/lib/radix/README.md @@ -0,0 +1,129 @@ +# Radix Tree + +[Radix tree](https://en.wikipedia.org/wiki/Radix_tree) implementation for +Crystal language + +[![Build Status](https://img.shields.io/travis/luislavena/radix/master.svg)](https://travis-ci.org/luislavena/radix) +[![Latest Release](https://img.shields.io/github/release/luislavena/radix.svg)](https://github.com/luislavena/radix/releases) + +## Installation + +Add this to your application's `shard.yml`: + +```yaml +dependencies: + radix: + github: luislavena/radix +``` + +## Usage + +### Building Trees + +You can associate a *payload* with each path added to the tree: + +```crystal +require "radix" + +tree = Radix::Tree(Symbol).new +tree.add "/products", :products +tree.add "/products/featured", :featured + +result = tree.find "/products/featured" + +if result.found? + puts result.payload # => :featured +end +``` + +The types allowed for payload are defined on Tree definition: + +```crystal +tree = Radix::Tree(Symbol).new + +# Good, since Symbol is allowed as payload +tree.add "/", :root + +# Compilation error, Int32 is not allowed +tree.add "/meaning-of-life", 42 +``` + +Can combine multiple types if needed: + +```crystal +tree = Radix::Tree(Int32 | String | Symbol).new + +tree.add "/", :root +tree.add "/meaning-of-life", 42 +tree.add "/hello", "world" +``` + +### Lookup and placeholders + +You can also extract values from placeholders (as named segments or globbing): + +```crystal +tree.add "/products/:id", :product + +result = tree.find "/products/1234" + +if result.found? + puts result.params["id"]? # => "1234" +end +``` + +Please see `Radix::Tree#add` documentation for more usage examples. + +## Caveats + +Pretty much all Radix implementations have their limitations and this project +is no exception. + +When designing and adding *paths* to a Tree, please consider that two different +named parameters cannot share the same level: + +```crystal +tree.add "/", :root +tree.add "/:post", :post +tree.add "/:category/:post", :category_post # => Radix::Tree::SharedKeyError +``` + +This is because different named parameters at the same level will result in +incorrect `params` when lookup is performed, and sometimes the value for +`post` or `category` parameters will not be stored as expected. + +To avoid this issue, usage of explicit keys that differentiate each path is +recommended. + +For example, following a good SEO practice will be consider `/:post` as +absolute permalink for the post and have a list of categories which links to +the permalinks of the posts under that category: + +```crystal +tree.add "/", :root +tree.add "/:post", :post # this is post permalink +tree.add "/categories", :categories # list of categories +tree.add "/categories/:category", :category # listing of posts under each category +``` + +## Implementation + +This project has been inspired and adapted from +[julienschmidt/httprouter](https://github.com/julienschmidt/httprouter) and +[spriet2000/vertx-http-router](https://github.com/spriet2000/vertx-http-router) +Go and Java implementations, respectively. + +Changes to logic and optimizations have been made to take advantage of +Crystal's features. + +## Contributing + +1. Fork it ( https://github.com/luislavena/radix/fork ) +2. Create your feature branch (`git checkout -b my-new-feature`) +3. Commit your changes (`git commit -am 'Add some feature'`) +4. Push to the branch (`git push origin my-new-feature`) +5. Create a new Pull Request + +## Contributors + +- [Luis Lavena](https://github.com/luislavena) - creator, maintainer diff --git a/lib/radix/shard.yml b/lib/radix/shard.yml new file mode 100644 index 00000000..70565cd6 --- /dev/null +++ b/lib/radix/shard.yml @@ -0,0 +1,7 @@ +name: radix +version: 0.3.9 + +authors: + - Luis Lavena + +license: MIT diff --git a/lib/radix/spec/radix/node_spec.cr b/lib/radix/spec/radix/node_spec.cr new file mode 100644 index 00000000..e43ffaa9 --- /dev/null +++ b/lib/radix/spec/radix/node_spec.cr @@ -0,0 +1,150 @@ +require "../spec_helper" + +module Radix + describe Node do + describe "#glob?" do + it "returns true when key contains a glob parameter (catch all)" do + node = Node(Nil).new("a") + node.glob?.should be_false + + node = Node(Nil).new("*filepath") + node.glob?.should be_true + end + end + + describe "#key=" do + it "accepts change of key after initialization" do + node = Node(Nil).new("abc") + node.key.should eq("abc") + + node.key = "xyz" + node.key.should eq("xyz") + end + + it "also changes kind when modified" do + node = Node(Nil).new("abc") + node.normal?.should be_true + + node.key = ":query" + node.normal?.should be_false + node.named?.should be_true + end + end + + describe "#named?" do + it "returns true when key contains a named parameter" do + node = Node(Nil).new("a") + node.named?.should be_false + + node = Node(Nil).new(":query") + node.named?.should be_true + end + end + + describe "#normal?" do + it "returns true when key does not contain named or glob parameters" do + node = Node(Nil).new("a") + node.normal?.should be_true + + node = Node(Nil).new(":query") + node.normal?.should be_false + + node = Node(Nil).new("*filepath") + node.normal?.should be_false + end + end + + describe "#payload" do + it "accepts any form of payload" do + node = Node.new("abc", :payload) + node.payload?.should be_truthy + node.payload.should eq(:payload) + + node = Node.new("abc", 1_000) + node.payload?.should be_truthy + node.payload.should eq(1_000) + end + + # This example focuses on the internal representation of `payload` + # as inferred from supplied types and default values. + # + # We cannot compare `typeof` against `property!` since it excludes `Nil` + # from the possible types. + it "makes optional to provide a payload" do + node = Node(Int32).new("abc") + node.payload?.should be_falsey + typeof(node.@payload).should eq(Int32 | Nil) + end + end + + describe "#priority" do + it "calculates it based on key length" do + node = Node(Nil).new("a") + node.priority.should eq(1) + + node = Node(Nil).new("abc") + node.priority.should eq(3) + end + + it "considers key length up until named parameter presence" do + node = Node(Nil).new("/posts/:id") + node.priority.should eq(7) + + node = Node(Nil).new("/u/:username") + node.priority.should eq(3) + end + + it "considers key length up until glob parameter presence" do + node = Node(Nil).new("/search/*query") + node.priority.should eq(8) + + node = Node(Nil).new("/*anything") + node.priority.should eq(1) + end + + it "changes when key changes" do + node = Node(Nil).new("a") + node.priority.should eq(1) + + node.key = "abc" + node.priority.should eq(3) + + node.key = "/src/*filepath" + node.priority.should eq(5) + + node.key = "/search/:query" + node.priority.should eq(8) + end + end + + describe "#sort!" do + it "orders children" do + root = Node(Int32).new("/") + node1 = Node(Int32).new("a", 1) + node2 = Node(Int32).new("bc", 2) + node3 = Node(Int32).new("def", 3) + + root.children.push(node1, node2, node3) + root.sort! + + root.children[0].should eq(node3) + root.children[1].should eq(node2) + root.children[2].should eq(node1) + end + + it "orders catch all and named parameters lower than normal nodes" do + root = Node(Int32).new("/") + node1 = Node(Int32).new("*filepath", 1) + node2 = Node(Int32).new("abc", 2) + node3 = Node(Int32).new(":query", 3) + + root.children.push(node1, node2, node3) + root.sort! + + root.children[0].should eq(node2) + root.children[1].should eq(node3) + root.children[2].should eq(node1) + end + end + end +end diff --git a/lib/radix/spec/radix/result_spec.cr b/lib/radix/spec/radix/result_spec.cr new file mode 100644 index 00000000..89ad2267 --- /dev/null +++ b/lib/radix/spec/radix/result_spec.cr @@ -0,0 +1,76 @@ +require "../spec_helper" + +module Radix + describe Result do + describe "#found?" do + context "a new instance" do + it "returns false when no payload is associated" do + result = Result(Nil).new + result.found?.should be_false + end + end + + context "with a payload" do + it "returns true" do + node = Node(Symbol).new("/", :root) + result = Result(Symbol).new + result.use node + + result.found?.should be_true + end + end + end + + describe "#key" do + context "a new instance" do + it "returns an empty key" do + result = Result(Nil).new + result.key.should eq("") + end + end + + context "given one used node" do + it "returns the node key" do + node = Node(Symbol).new("/", :root) + result = Result(Symbol).new + result.use node + + result.key.should eq("/") + end + end + + context "using multiple nodes" do + it "combines the node keys" do + node1 = Node(Symbol).new("/", :root) + node2 = Node(Symbol).new("about", :about) + result = Result(Symbol).new + result.use node1 + result.use node2 + + result.key.should eq("/about") + end + end + end + + describe "#use" do + it "uses the node payload" do + node = Node(Symbol).new("/", :root) + result = Result(Symbol).new + result.payload?.should be_falsey + + result.use node + result.payload?.should be_truthy + result.payload.should eq(node.payload) + end + + it "allow not to assign payload" do + node = Node(Symbol).new("/", :root) + result = Result(Symbol).new + result.payload?.should be_falsey + + result.use node, payload: false + result.payload?.should be_falsey + end + end + end +end diff --git a/lib/radix/spec/radix/tree_spec.cr b/lib/radix/spec/radix/tree_spec.cr new file mode 100644 index 00000000..760de45f --- /dev/null +++ b/lib/radix/spec/radix/tree_spec.cr @@ -0,0 +1,626 @@ +require "../spec_helper" + +# Silence deprecation warnings when running specs and allow +# capture them for inspection. +module Radix + class Tree(T) + @show_deprecations = false + @stderr : IO::Memory? + + def show_deprecations! + @show_deprecations = true + end + + private def deprecation(message) + if @show_deprecations + @stderr ||= IO::Memory.new + @stderr.not_nil!.puts message + end + end + end +end + +# Simple Payload class +record Payload + +module Radix + describe Tree do + context "a new instance" do + it "contains a root placeholder node" do + tree = Tree(Symbol).new + tree.root.should be_a(Node(Symbol)) + tree.root.payload?.should be_falsey + tree.root.placeholder?.should be_true + end + end + + describe "#add" do + context "on a new instance" do + it "replaces placeholder with new node" do + tree = Tree(Symbol).new + tree.add "/abc", :abc + tree.root.should be_a(Node(Symbol)) + tree.root.placeholder?.should be_false + tree.root.payload?.should be_truthy + tree.root.payload.should eq(:abc) + end + end + + context "shared root" do + it "inserts properly adjacent nodes" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/a", :a + tree.add "/bc", :bc + + # / (:root) + # +-bc (:bc) + # \-a (:a) + tree.root.children.size.should eq(2) + tree.root.children[0].key.should eq("bc") + tree.root.children[0].payload.should eq(:bc) + tree.root.children[1].key.should eq("a") + tree.root.children[1].payload.should eq(:a) + end + + it "inserts nodes with shared parent" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/abc", :abc + tree.add "/axyz", :axyz + + # / (:root) + # +-a + # +-xyz (:axyz) + # \-bc (:abc) + tree.root.children.size.should eq(1) + tree.root.children[0].key.should eq("a") + tree.root.children[0].children.size.should eq(2) + tree.root.children[0].children[0].key.should eq("xyz") + tree.root.children[0].children[1].key.should eq("bc") + end + + it "inserts multiple parent nodes" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/admin/users", :users + tree.add "/admin/products", :products + tree.add "/blog/tags", :tags + tree.add "/blog/articles", :articles + + # / (:root) + # +-admin/ + # | +-products (:products) + # | \-users (:users) + # | + # +-blog/ + # +-articles (:articles) + # \-tags (:tags) + tree.root.children.size.should eq(2) + tree.root.children[0].key.should eq("admin/") + tree.root.children[0].payload?.should be_falsey + tree.root.children[0].children[0].key.should eq("products") + tree.root.children[0].children[1].key.should eq("users") + tree.root.children[1].key.should eq("blog/") + tree.root.children[1].payload?.should be_falsey + tree.root.children[1].children[0].key.should eq("articles") + tree.root.children[1].children[0].payload?.should be_truthy + tree.root.children[1].children[1].key.should eq("tags") + tree.root.children[1].children[1].payload?.should be_truthy + end + + it "inserts multiple nodes with mixed parents" do + tree = Tree(Symbol).new + tree.add "/authorizations", :authorizations + tree.add "/authorizations/:id", :authorization + tree.add "/applications", :applications + tree.add "/events", :events + + # / + # +-events (:events) + # +-a + # +-uthorizations (:authorizations) + # | \-/:id (:authorization) + # \-pplications (:applications) + tree.root.children.size.should eq(2) + tree.root.children[1].key.should eq("a") + tree.root.children[1].children.size.should eq(2) + tree.root.children[1].children[0].payload.should eq(:authorizations) + tree.root.children[1].children[1].payload.should eq(:applications) + end + + it "supports insertion of mixed routes out of order" do + tree = Tree(Symbol).new + tree.add "/user/repos", :my_repos + tree.add "/users/:user/repos", :user_repos + tree.add "/users/:user", :user + tree.add "/user", :me + + # /user (:me) + # +-/repos (:my_repos) + # \-s/:user (:user) + # \-/repos (:user_repos) + tree.root.key.should eq("/user") + tree.root.payload?.should be_truthy + tree.root.payload.should eq(:me) + tree.root.children.size.should eq(2) + tree.root.children[0].key.should eq("/repos") + tree.root.children[1].key.should eq("s/:user") + tree.root.children[1].payload.should eq(:user) + tree.root.children[1].children[0].key.should eq("/repos") + end + end + + context "mixed payloads" do + it "allows node with different payloads" do + payload1 = Payload.new + payload2 = Payload.new + + tree = Tree(Payload | Symbol).new + tree.add "/", :root + tree.add "/a", payload1 + tree.add "/bc", payload2 + + # / (:root) + # +-bc (payload2) + # \-a (payload1) + tree.root.children.size.should eq(2) + tree.root.children[0].key.should eq("bc") + tree.root.children[0].payload.should eq(payload2) + tree.root.children[1].key.should eq("a") + tree.root.children[1].payload.should eq(payload1) + end + end + + context "dealing with unicode" do + it "inserts properly adjacent parent nodes" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/日本語", :japanese + tree.add "/素晴らしい", :amazing + + # / (:root) + # +-素晴らしい (:amazing) + # \-日本語 (:japanese) + tree.root.children.size.should eq(2) + tree.root.children[0].key.should eq("素晴らしい") + tree.root.children[1].key.should eq("日本語") + end + + it "inserts nodes with shared parent" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/日本語", :japanese + tree.add "/日本は難しい", :japanese_is_difficult + + # / (:root) + # \-日本語 (:japanese) + # \-日本は難しい (:japanese_is_difficult) + tree.root.children.size.should eq(1) + tree.root.children[0].key.should eq("日本") + tree.root.children[0].children.size.should eq(2) + tree.root.children[0].children[0].key.should eq("は難しい") + tree.root.children[0].children[1].key.should eq("語") + end + end + + context "dealing with duplicates" do + it "does not allow same path be defined twice" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/abc", :abc + + expect_raises Tree::DuplicateError do + tree.add "/", :other + end + + tree.root.children.size.should eq(1) + end + end + + context "dealing with catch all and named parameters" do + it "prioritizes nodes correctly" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/*filepath", :all + tree.add "/products", :products + tree.add "/products/:id", :product + tree.add "/products/:id/edit", :edit + tree.add "/products/featured", :featured + + # / (:all) + # +-products (:products) + # | \-/ + # | +-featured (:featured) + # | \-:id (:product) + # | \-/edit (:edit) + # \-*filepath (:all) + tree.root.children.size.should eq(2) + tree.root.children[0].key.should eq("products") + tree.root.children[0].children[0].key.should eq("/") + + nodes = tree.root.children[0].children[0].children + nodes.size.should eq(2) + nodes[0].key.should eq("featured") + nodes[1].key.should eq(":id") + nodes[1].children[0].key.should eq("/edit") + + tree.root.children[1].key.should eq("*filepath") + end + + it "does not split named parameters across shared key" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/:category", :category + tree.add "/:category/:subcategory", :subcategory + + # / (:root) + # +-:category (:category) + # \-/:subcategory (:subcategory) + tree.root.children.size.should eq(1) + tree.root.children[0].key.should eq(":category") + + # inner children + tree.root.children[0].children.size.should eq(1) + tree.root.children[0].children[0].key.should eq("/:subcategory") + end + + it "does allow same named parameter in different order of insertion" do + tree = Tree(Symbol).new + tree.add "/members/:id/edit", :member_edit + tree.add "/members/export", :members_export + tree.add "/members/:id/videos", :member_videos + + # /members/ + # +-export (:members_export) + # \-:id/ + # +-videos (:members_videos) + # \-edit (:members_edit) + tree.root.key.should eq("/members/") + tree.root.children.size.should eq(2) + + # first level children nodes + tree.root.children[0].key.should eq("export") + tree.root.children[1].key.should eq(":id/") + + # inner children + nodes = tree.root.children[1].children + nodes[0].key.should eq("videos") + nodes[1].key.should eq("edit") + end + + it "does not allow different named parameters sharing same level" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/:post", :post + + expect_raises Tree::SharedKeyError do + tree.add "/:category/:post", :category_post + end + end + end + end + + describe "#find" do + context "a single node" do + it "does not find when using different path" do + tree = Tree(Symbol).new + tree.add "/about", :about + + result = tree.find "/products" + result.found?.should be_false + end + + it "finds when key and path matches" do + tree = Tree(Symbol).new + tree.add "/about", :about + + result = tree.find "/about" + result.found?.should be_true + result.key.should eq("/about") + result.payload?.should be_truthy + result.payload.should eq(:about) + end + + it "finds when path contains trailing slash" do + tree = Tree(Symbol).new + tree.add "/about", :about + + result = tree.find "/about/" + result.found?.should be_true + result.key.should eq("/about") + end + + it "finds when key contains trailing slash" do + tree = Tree(Symbol).new + tree.add "/about/", :about + + result = tree.find "/about" + result.found?.should be_true + result.key.should eq("/about/") + result.payload.should eq(:about) + end + end + + context "nodes with shared parent" do + it "finds matching path" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/abc", :abc + tree.add "/axyz", :axyz + + result = tree.find("/abc") + result.found?.should be_true + result.key.should eq("/abc") + result.payload.should eq(:abc) + end + + it "finds matching path across separator" do + tree = Tree(Symbol).new + tree.add "/products", :products + tree.add "/product/new", :product_new + + result = tree.find("/products") + result.found?.should be_true + result.key.should eq("/products") + result.payload.should eq(:products) + end + + it "finds matching path across parents" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/admin/users", :users + tree.add "/admin/products", :products + tree.add "/blog/tags", :tags + tree.add "/blog/articles", :articles + + result = tree.find("/blog/tags/") + result.found?.should be_true + result.key.should eq("/blog/tags") + result.payload.should eq(:tags) + end + end + + context "unicode nodes with shared parent" do + it "finds matching path" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/日本語", :japanese + tree.add "/日本日本語は難しい", :japanese_is_difficult + + result = tree.find("/日本日本語は難しい/") + result.found?.should be_true + result.key.should eq("/日本日本語は難しい") + end + end + + context "dealing with catch all" do + it "finds matching path" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/*filepath", :all + tree.add "/about", :about + + result = tree.find("/src/file.png") + result.found?.should be_true + result.key.should eq("/*filepath") + result.payload.should eq(:all) + end + + it "returns catch all in parameters" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/*filepath", :all + tree.add "/about", :about + + result = tree.find("/src/file.png") + result.found?.should be_true + result.params.has_key?("filepath").should be_true + result.params["filepath"].should eq("src/file.png") + end + + it "returns optional catch all after slash" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/search/*extra", :extra + + result = tree.find("/search") + result.found?.should be_true + result.key.should eq("/search/*extra") + result.params.has_key?("extra").should be_true + result.params["extra"].empty?.should be_true + end + + it "returns optional catch all by globbing" do + tree = Tree(Symbol).new + tree.add "/members*trailing", :members_catch_all + + result = tree.find("/members") + result.found?.should be_true + result.key.should eq("/members*trailing") + result.params.has_key?("trailing").should be_true + result.params["trailing"].empty?.should be_true + end + + it "does not find when catch all is not full match" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/search/public/*query", :search + + result = tree.find("/search") + result.found?.should be_false + end + + it "does not find when path search has been exhausted" do + tree = Tree(Symbol).new + tree.add "/members/*trailing", :members_catch_all + + result = tree.find("/members2") + result.found?.should be_false + end + + it "does prefer specific path over catch all if both are present" do + tree = Tree(Symbol).new + tree.add "/members", :members + tree.add "/members*trailing", :members_catch_all + + result = tree.find("/members") + result.found?.should be_true + result.key.should eq("/members") + end + + it "does prefer catch all over specific key with partially shared key" do + tree = Tree(Symbol).new + tree.add "/orders/*anything", :orders_catch_all + tree.add "/orders/closed", :closed_orders + + result = tree.find("/orders/cancelled") + result.found?.should be_true + result.key.should eq("/orders/*anything") + result.params.has_key?("anything").should be_true + result.params["anything"].should eq("cancelled") + end + end + + context "dealing with named parameters" do + it "finds matching path" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/products", :products + tree.add "/products/:id", :product + tree.add "/products/:id/edit", :edit + + result = tree.find("/products/10") + result.found?.should be_true + result.key.should eq("/products/:id") + result.payload.should eq(:product) + end + + it "does not find partial matching path" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/products", :products + tree.add "/products/:id/edit", :edit + + result = tree.find("/products/10") + result.found?.should be_false + end + + it "returns named parameters in result" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/products", :products + tree.add "/products/:id", :product + tree.add "/products/:id/edit", :edit + + result = tree.find("/products/10/edit") + result.found?.should be_true + result.params.has_key?("id").should be_true + result.params["id"].should eq("10") + end + + it "returns unicode values in parameters" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/language/:name", :language + tree.add "/language/:name/about", :about + + result = tree.find("/language/日本語") + result.found?.should be_true + result.params.has_key?("name").should be_true + result.params["name"].should eq("日本語") + end + + it "does prefer specific path over named parameters one if both are present" do + tree = Tree(Symbol).new + tree.add "/tag-edit/:tag", :edit_tag + tree.add "/tag-edit2", :alternate_tag_edit + + result = tree.find("/tag-edit2") + result.found?.should be_true + result.key.should eq("/tag-edit2") + end + + it "does prefer named parameter over specific key with partially shared key" do + tree = Tree(Symbol).new + tree.add "/orders/:id", :specific_order + tree.add "/orders/closed", :closed_orders + + result = tree.find("/orders/10") + result.found?.should be_true + result.key.should eq("/orders/:id") + result.params.has_key?("id").should be_true + result.params["id"].should eq("10") + end + end + + context "dealing with multiple named parameters" do + it "finds matching path" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/:section/:page", :static_page + + result = tree.find("/about/shipping") + result.found?.should be_true + result.key.should eq("/:section/:page") + result.payload.should eq(:static_page) + end + + it "returns named parameters in result" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/:section/:page", :static_page + + result = tree.find("/about/shipping") + result.found?.should be_true + + result.params.has_key?("section").should be_true + result.params["section"].should eq("about") + + result.params.has_key?("page").should be_true + result.params["page"].should eq("shipping") + end + end + + context "dealing with both catch all and named parameters" do + it "finds matching path" do + tree = Tree(Symbol).new + tree.add "/", :root + tree.add "/*filepath", :all + tree.add "/products", :products + tree.add "/products/:id", :product + tree.add "/products/:id/edit", :edit + tree.add "/products/featured", :featured + + result = tree.find("/products/1000") + result.found?.should be_true + result.key.should eq("/products/:id") + result.payload.should eq(:product) + + result = tree.find("/admin/articles") + result.found?.should be_true + result.key.should eq("/*filepath") + result.params["filepath"].should eq("admin/articles") + + result = tree.find("/products/featured") + result.found?.should be_true + result.key.should eq("/products/featured") + result.payload.should eq(:featured) + end + end + + context "dealing with named parameters and shared key" do + it "finds matching path" do + tree = Tree(Symbol).new + tree.add "/one/:id", :one + tree.add "/one-longer/:id", :two + + result = tree.find "/one-longer/10" + result.found?.should be_true + result.key.should eq("/one-longer/:id") + result.params["id"].should eq("10") + end + end + end + end +end diff --git a/lib/radix/spec/radix/version_spec.cr b/lib/radix/spec/radix/version_spec.cr new file mode 100644 index 00000000..30ddc084 --- /dev/null +++ b/lib/radix/spec/radix/version_spec.cr @@ -0,0 +1,12 @@ +require "../spec_helper" +require "yaml" + +describe "Radix::VERSION" do + it "matches version defined in shard.yml" do + contents = File.read(File.expand_path("../../../shard.yml", __FILE__)) + meta = YAML.parse(contents) + + meta["version"]?.should_not be_falsey + Radix::VERSION.should eq(meta["version"].as_s) + end +end diff --git a/lib/radix/spec/spec_helper.cr b/lib/radix/spec/spec_helper.cr new file mode 100644 index 00000000..6fa0c75d --- /dev/null +++ b/lib/radix/spec/spec_helper.cr @@ -0,0 +1,2 @@ +require "spec" +require "../src/radix" diff --git a/lib/radix/src/radix.cr b/lib/radix/src/radix.cr new file mode 100644 index 00000000..97ff540e --- /dev/null +++ b/lib/radix/src/radix.cr @@ -0,0 +1,2 @@ +require "./radix/tree" +require "./radix/version" diff --git a/lib/radix/src/radix/node.cr b/lib/radix/src/radix/node.cr new file mode 100644 index 00000000..a87d39eb --- /dev/null +++ b/lib/radix/src/radix/node.cr @@ -0,0 +1,207 @@ +module Radix + # A Node represents one element in the structure of a [Radix tree](https://en.wikipedia.org/wiki/Radix_tree) + # + # Carries a *payload* and might also contain references to other nodes + # down in the organization inside *children*. + # + # Each node also carries identification in relation to the kind of key it + # contains, which helps with characteristics of the node like named + # parameters or catch all kind (globbing). + # + # Is not expected direct usage of a node but instead manipulation via + # methods within `Tree`. + class Node(T) + include Comparable(self) + + # :nodoc: + enum Kind : UInt8 + Normal + Named + Glob + end + + getter key + getter? placeholder + property children = [] of Node(T) + property! payload : T | Nil + + # :nodoc: + protected getter kind = Kind::Normal + + # Returns the priority of the Node based on it's *key* + # + # This value will be directly associated to the key size up until a + # special elements is found. + # + # ``` + # Radix::Node(Nil).new("a").priority + # # => 1 + # + # Radix::Node(Nil).new("abc").priority + # # => 3 + # + # Radix::Node(Nil).new("/src/*filepath").priority + # # => 5 + # + # Radix::Node(Nil).new("/search/:query").priority + # # => 8 + # ``` + getter priority : Int32 + + # Instantiate a Node + # + # - *key* - A `String` that represents this node. + # - *payload* - An optional payload for this node. + # + # When *payload* is not supplied, ensure the type of the node is provided + # instead: + # + # ``` + # # Good, node type is inferred from payload (Symbol) + # node = Radix::Node.new("/", :root) + # + # # Good, node type is now Int32 but payload is optional + # node = Radix::Node(Int32).new("/") + # + # # Error, node type cannot be inferred (compiler error) + # node = Radix::Node.new("/") + # ``` + def initialize(@key : String, @payload : T? = nil, @placeholder = false) + @priority = compute_priority + end + + # Compares this node against *other*, returning `-1`, `0` or `1` depending + # on whether this node differentiates from *other*. + # + # Comparison is done combining node's `kind` and `priority`. Nodes of + # same kind are compared by priority. Nodes of different kind are + # ranked. + # + # ### Normal nodes + # + # ``` + # node1 = Radix::Node(Nil).new("a") # normal + # node2 = Radix::Node(Nil).new("bc") # normal + # node1 <=> node2 # => 1 + # ``` + # + # ### Normal vs named or glob nodes + # + # ``` + # node1 = Radix::Node(Nil).new("a") # normal + # node2 = Radix::Node(Nil).new(":query") # named + # node3 = Radix::Node(Nil).new("*filepath") # glob + # node1 <=> node2 # => -1 + # node1 <=> node3 # => -1 + # ``` + # + # ### Named vs glob nodes + # + # ``` + # node1 = Radix::Node(Nil).new(":query") # named + # node2 = Radix::Node(Nil).new("*filepath") # glob + # node1 <=> node2 # => -1 + # ``` + def <=>(other : self) + result = kind <=> other.kind + return result if result != 0 + + other.priority <=> priority + end + + # Returns `true` if the node key contains a glob parameter in it + # (catch all) + # + # ``` + # node = Radix::Node(Nil).new("*filepath") + # node.glob? # => true + # + # node = Radix::Node(Nil).new("abc") + # node.glob? # => false + # ``` + def glob? + kind.glob? + end + + # Changes current *key* + # + # ``` + # node = Radix::Node(Nil).new("a") + # node.key + # # => "a" + # + # node.key = "b" + # node.key + # # => "b" + # ``` + # + # This will also result in change of node's `priority` + # + # ``` + # node = Radix::Node(Nil).new("a") + # node.priority + # # => 1 + # + # node.key = "abcdef" + # node.priority + # # => 6 + # ``` + def key=(@key) + # reset kind on change of key + @kind = Kind::Normal + @priority = compute_priority + end + + # Returns `true` if the node key contains a named parameter in it + # + # ``` + # node = Radix::Node(Nil).new(":query") + # node.named? # => true + # + # node = Radix::Node(Nil).new("abc") + # node.named? # => false + # ``` + def named? + kind.named? + end + + # Returns `true` if the node key does not contain an special parameter + # (named or glob) + # + # ``` + # node = Radix::Node(Nil).new("a") + # node.normal? # => true + # + # node = Radix::Node(Nil).new(":query") + # node.normal? # => false + # ``` + def normal? + kind.normal? + end + + # :nodoc: + private def compute_priority + reader = Char::Reader.new(@key) + + while reader.has_next? + case reader.current_char + when '*' + @kind = Kind::Glob + break + when ':' + @kind = Kind::Named + break + else + reader.next_char + end + end + + reader.pos + end + + # :nodoc: + protected def sort! + @children.sort! + end + end +end diff --git a/lib/radix/src/radix/result.cr b/lib/radix/src/radix/result.cr new file mode 100644 index 00000000..ad0a0bb5 --- /dev/null +++ b/lib/radix/src/radix/result.cr @@ -0,0 +1,88 @@ +require "./node" + +module Radix + # A Result is the comulative output of walking our [Radix tree](https://en.wikipedia.org/wiki/Radix_tree) + # `Radix::Tree` implementation. + # + # It provides helpers to retrieve the information obtained from walking + # our tree using `Radix::Tree#find` + # + # This information can be used to perform actions in case of the *path* + # that was looked on the Tree was found. + # + # A Result is also used recursively by `Radix::Tree#find` when collecting + # extra information like *params*. + class Result(T) + @key : String? + + getter params + getter! payload : T? + + # :nodoc: + def initialize + @nodes = [] of Node(T) + @params = {} of String => String + end + + # Returns whatever a *payload* was found by `Tree#find` and is part of + # the result. + # + # ``` + # result = Radix::Result(Symbol).new + # result.found? + # # => false + # + # root = Radix::Node(Symbol).new("/", :root) + # result.use(root) + # result.found? + # # => true + # ``` + def found? + payload? ? true : false + end + + # Returns a String built based on the nodes used in the result + # + # ``` + # node1 = Radix::Node(Symbol).new("/", :root) + # node2 = Radix::Node(Symbol).new("about", :about) + # + # result = Radix::Result(Symbol).new + # result.use node1 + # result.use node2 + # + # result.key + # # => "/about" + # ``` + # + # When no node has been used, returns an empty String. + # + # ``` + # result = Radix::Result(Nil).new + # result.key + # # => "" + # ``` + def key + @key ||= begin + String.build { |io| + @nodes.each do |node| + io << node.key + end + } + end + end + + # Adjust result information by using the details of the given `Node`. + # + # * Collect `Node` for future references. + # * Use *payload* if present. + def use(node : Node(T), payload = true) + # collect nodes + @nodes << node + + if payload && node.payload? + @payload = node.payload + end + end + end +end diff --git a/lib/radix/src/radix/tree.cr b/lib/radix/src/radix/tree.cr new file mode 100644 index 00000000..d6f4a969 --- /dev/null +++ b/lib/radix/src/radix/tree.cr @@ -0,0 +1,472 @@ +require "./node" +require "./result" + +module Radix + # A [Radix tree](https://en.wikipedia.org/wiki/Radix_tree) implementation. + # + # It allows insertion of *path* elements that will be organized inside + # the tree aiming to provide fast retrieval options. + # + # Each inserted *path* will be represented by a `Node` or segmented and + # distributed within the `Tree`. + # + # You can associate a *payload* at insertion which will be return back + # at retrieval time. + class Tree(T) + # :nodoc: + class DuplicateError < Exception + def initialize(path) + super("Duplicate trail found '#{path}'") + end + end + + # :nodoc: + class SharedKeyError < Exception + def initialize(new_key, existing_key) + super("Tried to place key '#{new_key}' at same level as '#{existing_key}'") + end + end + + # Returns the root `Node` element of the Tree. + # + # On a new tree instance, this will be a placeholder. + getter root : Node(T) + + def initialize + @root = Node(T).new("", placeholder: true) + end + + # Inserts given *path* into the Tree + # + # * *path* - An `String` representing the pattern to be inserted. + # * *payload* - Required associated element for this path. + # + # If no previous elements existed in the Tree, this will replace the + # defined placeholder. + # + # ``` + # tree = Radix::Tree(Symbol).new + # + # # / (:root) + # tree.add "/", :root + # + # # / (:root) + # # \-abc (:abc) + # tree.add "/abc", :abc + # + # # / (:root) + # # \-abc (:abc) + # # \-xyz (:xyz) + # tree.add "/abcxyz", :xyz + # ``` + # + # Nodes inside the tree will be adjusted to accommodate the different + # segments of the given *path*. + # + # ``` + # tree = Radix::Tree(Symbol).new + # + # # / (:root) + # tree.add "/", :root + # + # # / (:root) + # # \-products/:id (:product) + # tree.add "/products/:id", :product + # + # # / (:root) + # # \-products/ + # # +-featured (:featured) + # # \-:id (:product) + # tree.add "/products/featured", :featured + # ``` + # + # Catch all (globbing) and named parameters *path* will be located with + # lower priority against other nodes. + # + # ``` + # tree = Radix::Tree(Symbol).new + # + # # / (:root) + # tree.add "/", :root + # + # # / (:root) + # # \-*filepath (:all) + # tree.add "/*filepath", :all + # + # # / (:root) + # # +-about (:about) + # # \-*filepath (:all) + # tree.add "/about", :about + # ``` + def add(path : String, payload : T) + root = @root + + # replace placeholder with new node + if root.placeholder? + @root = Node(T).new(path, payload) + else + add path, payload, root + end + end + + # :nodoc: + private def add(path : String, payload : T, node : Node(T)) + key_reader = Char::Reader.new(node.key) + path_reader = Char::Reader.new(path) + + # move cursor position to last shared character between key and path + while path_reader.has_next? && key_reader.has_next? + break if path_reader.current_char != key_reader.current_char + + path_reader.next_char + key_reader.next_char + end + + # determine split point difference between path and key + # compare if path is larger than key + if path_reader.pos == 0 || + (path_reader.pos < path.bytesize && path_reader.pos >= node.key.bytesize) + # determine if a child of this node contains the remaining part + # of the path + added = false + + new_key = path_reader.string.byte_slice(path_reader.pos) + node.children.each do |child| + # if child's key starts with named parameter, compare key until + # separator (if present). + # Otherwise, compare just first character + if child.key[0]? == ':' && new_key[0]? == ':' + unless _same_key?(new_key, child.key) + raise SharedKeyError.new(new_key, child.key) + end + else + next unless child.key[0]? == new_key[0]? + end + + # when found, add to this child + added = true + add new_key, payload, child + break + end + + # if no existing child shared part of the key, add a new one + unless added + node.children << Node(T).new(new_key, payload) + end + + # adjust priorities + node.sort! + elsif path_reader.pos == path.bytesize && path_reader.pos == node.key.bytesize + # determine if path matches key and potentially be a duplicate + # and raise if is the case + + if node.payload? + raise DuplicateError.new(path) + else + # assign payload since this is an empty node + node.payload = payload + end + elsif path_reader.pos > 0 && path_reader.pos < node.key.bytesize + # determine if current node key needs to be split to accomodate new + # children nodes + + # build new node with partial key and adjust existing one + new_key = node.key.byte_slice(path_reader.pos) + swap_payload = node.payload? ? node.payload : nil + + new_node = Node(T).new(new_key, swap_payload) + new_node.children.replace(node.children) + + # clear payload and children (this is no longer and endpoint) + node.payload = nil + node.children.clear + + # adjust existing node key to new partial one + node.key = path_reader.string.byte_slice(0, path_reader.pos) + node.children << new_node + node.sort! + + # determine if path still continues + if path_reader.pos < path.bytesize + new_key = path.byte_slice(path_reader.pos) + node.children << Node(T).new(new_key, payload) + node.sort! + + # clear payload (no endpoint) + node.payload = nil + else + # this is an endpoint, set payload + node.payload = payload + end + end + end + + # Returns a `Result` instance after walking the tree looking up for + # *path* + # + # It will start walking the tree from the root node until a matching + # endpoint is found (or not). + # + # ``` + # tree = Radix::Tree(Symbol).new + # tree.add "/about", :about + # + # result = tree.find "/products" + # result.found? + # # => false + # + # result = tree.find "/about" + # result.found? + # # => true + # + # result.payload + # # => :about + # ``` + def find(path : String) + result = Result(T).new + root = @root + + # walk the tree from root (first time) + find path, result, root, first: true + + result + end + + # :nodoc: + private def find(path : String, result : Result, node : Node, first = false) + # special consideration when comparing the first node vs. others + # in case of node key and path being the same, return the node + # instead of walking character by character + if first && (path.bytesize == node.key.bytesize && path == node.key) && node.payload? + result.use node + return + end + + key_reader = Char::Reader.new(node.key) + path_reader = Char::Reader.new(path) + + # walk both path and key while both have characters and they continue + # to match. Consider as special cases named parameters and catch all + # rules. + while key_reader.has_next? && path_reader.has_next? && + (key_reader.current_char == '*' || + key_reader.current_char == ':' || + path_reader.current_char == key_reader.current_char) + case key_reader.current_char + when '*' + # deal with catch all (globbing) parameter + # extract parameter name from key (exclude *) and value from path + name = key_reader.string.byte_slice(key_reader.pos + 1) + value = path_reader.string.byte_slice(path_reader.pos) + + # add this to result + result.params[name] = value + + result.use node + return + when ':' + # deal with named parameter + # extract parameter name from key (from : until / or EOL) and + # value from path (same rules as key) + key_size = _detect_param_size(key_reader) + path_size = _detect_param_size(path_reader) + + # obtain key and value using calculated sizes + # for name: skip ':' by moving one character forward and compensate + # key size. + name = key_reader.string.byte_slice(key_reader.pos + 1, key_size - 1) + value = path_reader.string.byte_slice(path_reader.pos, path_size) + + # add this information to result + result.params[name] = value + + # advance readers positions + key_reader.pos += key_size + path_reader.pos += path_size + else + # move to the next character + key_reader.next_char + path_reader.next_char + end + end + + # check if we reached the end of the path & key + if !path_reader.has_next? && !key_reader.has_next? + # check endpoint + if node.payload? + result.use node + return + end + end + + # still path to walk, check for possible trailing slash or children + # nodes + if path_reader.has_next? + # using trailing slash? + if node.key.bytesize > 0 && + path_reader.pos + 1 == path.bytesize && + path_reader.current_char == '/' + result.use node + return + end + + # not found in current node, check inside children nodes + new_path = path_reader.string.byte_slice(path_reader.pos) + node.children.each do |child| + # check if child key is a named parameter, catch all or shares parts + # with new path + if (child.key[0]? == '*' || child.key[0]? == ':') || + _shared_key?(new_path, child.key) + # consider this node for key but don't use payload + result.use node, payload: false + + find new_path, result, child + return + end + end + + # path differs from key, no use searching anymore + return + end + + # key still contains characters to walk + if key_reader.has_next? + # determine if there is just a trailing slash? + if key_reader.pos + 1 == node.key.bytesize && + key_reader.current_char == '/' + result.use node + return + end + + # check if remaining part is catch all + if key_reader.pos < node.key.bytesize && + ((key_reader.current_char == '/' && key_reader.peek_next_char == '*') || + key_reader.current_char == '*') + # skip to '*' only if necessary + unless key_reader.current_char == '*' + key_reader.next_char + end + + # deal with catch all, but since there is nothing in the path + # return parameter as empty + name = key_reader.string.byte_slice(key_reader.pos + 1) + + result.params[name] = "" + + result.use node + return + end + end + end + + # :nodoc: + private def _detect_param_size(reader) + # save old position + old_pos = reader.pos + + # move forward until '/' or EOL is detected + while reader.has_next? + break if reader.current_char == '/' + + reader.next_char + end + + # calculate the size + count = reader.pos - old_pos + + # restore old position + reader.pos = old_pos + + count + end + + # Internal: allow inline comparison of *char* against 3 defined markers: + # + # - Path separator (`/`) + # - Named parameter (`:`) + # - Catch all (`*`) + @[AlwaysInline] + private def _check_markers(char) + (char == '/' || char == ':' || char == '*') + end + + # Internal: Compares *path* against *key* for differences until the + # following criteria is met: + # + # - End of *path* or *key* is reached. + # - A separator (`/`) is found. + # - A character between *path* or *key* differs + # + # ``` + # _same_key?("foo", "bar") # => false (mismatch at 1st character) + # _same_key?("foo/bar", "foo/baz") # => true (only `foo` is compared) + # _same_key?("zipcode", "zip") # => false (`zip` is shorter) + # ``` + private def _same_key?(path, key) + path_reader = Char::Reader.new(path) + key_reader = Char::Reader.new(key) + + different = false + + while (path_reader.has_next? && path_reader.current_char != '/') && + (key_reader.has_next? && key_reader.current_char != '/') + if path_reader.current_char != key_reader.current_char + different = true + break + end + + path_reader.next_char + key_reader.next_char + end + + (!different) && + (path_reader.current_char == '/' || !path_reader.has_next?) + end + + # Internal: Compares *path* against *key* for equality until one of the + # following criterias is met: + # + # - End of *path* or *key* is reached. + # - A separator (`/`) is found. + # - A named parameter (`:`) or catch all (`*`) is found. + # - A character in *path* differs from *key* + # + # ``` + # _shared_key?("foo", "bar") # => false (mismatch at 1st character) + # _shared_key?("foo/bar", "foo/baz") # => true (only `foo` is compared) + # _shared_key?("zipcode", "zip") # => true (only `zip` is compared) + # _shared_key?("s", "/new") # => false (1st character is a separator) + # ``` + private def _shared_key?(path, key) + path_reader = Char::Reader.new(path) + key_reader = Char::Reader.new(key) + + if (path_reader.current_char != key_reader.current_char) && + _check_markers(key_reader.current_char) + return false + end + + different = false + + while (path_reader.has_next? && !_check_markers(path_reader.current_char)) && + (key_reader.has_next? && !_check_markers(key_reader.current_char)) + if path_reader.current_char != key_reader.current_char + different = true + break + end + + path_reader.next_char + key_reader.next_char + end + + (!different) && + (!key_reader.has_next? || _check_markers(key_reader.current_char)) + end + + # :nodoc: + private def deprecation(message : String) + STDERR.puts message + STDERR.flush + end + end +end diff --git a/lib/radix/src/radix/version.cr b/lib/radix/src/radix/version.cr new file mode 100644 index 00000000..f29f1a7c --- /dev/null +++ b/lib/radix/src/radix/version.cr @@ -0,0 +1,3 @@ +module Radix + VERSION = "0.3.9" +end diff --git a/lib/sqlite3/.gitignore b/lib/sqlite3/.gitignore new file mode 100644 index 00000000..2e5d65e9 --- /dev/null +++ b/lib/sqlite3/.gitignore @@ -0,0 +1,10 @@ +/doc/ +/lib/ +/.crystal/ +/.shards/ + + +# Libraries don't need dependency lock +# Dependencies will be locked in application that uses them +/shard.lock + diff --git a/lib/sqlite3/.travis.yml b/lib/sqlite3/.travis.yml new file mode 100644 index 00000000..dfd532a7 --- /dev/null +++ b/lib/sqlite3/.travis.yml @@ -0,0 +1,2 @@ +language: crystal +sudo: false diff --git a/lib/sqlite3/CHANGELOG.md b/lib/sqlite3/CHANGELOG.md new file mode 100644 index 00000000..cea6dae5 --- /dev/null +++ b/lib/sqlite3/CHANGELOG.md @@ -0,0 +1,33 @@ +## v0.13.0 (2019-08-02) + +* Fix compatibility issues for Crystal 0.30.0. ([#43](https://github.com/crystal-lang/crystal-sqlite3/pull/43)) + +## v0.12.0 (2019-06-07) + +This release requires crystal >= 0.28.0 + +* Fix compatibility issues for crystal 0.29.0 ([#40](https://github.com/crystal-lang/crystal-sqlite3/pull/40)) + +## v0.11.0 (2019-04-18) + +* Fix compatibility issues for crystal 0.28.0 ([#38](https://github.com/crystal-lang/crystal-sqlite3/pull/38)) +* Add complete list of `LibSQLite3::Code` values. ([#36](https://github.com/crystal-lang/crystal-sqlite3/pull/36), thanks @t-richards) + +## v0.10.0 (2018-06-18) + +* Fix compatibility issues for crystal 0.25.0 ([#34](https://github.com/crystal-lang/crystal-sqlite3/pull/34)) + * All the time instances are translated to UTC before saving them in the db + +## v0.9.0 (2017-12-31) + +* Update to crystal-db ~> 0.5.0 + +## v0.8.3 (2017-11-07) + +* Update to crystal-db ~> 0.4.1 +* Add `SQLite3::VERSION` constant with shard version. +* Add support for multi-steps statements execution. (see [#27](https://github.com/crystal-lang/crystal-sqlite3/pull/27), thanks @t-richards) +* Fix how resources are released. (see [#23](https://github.com/crystal-lang/crystal-sqlite3/pull/23), thanks @benoist) +* Fix blob c bindings. (see [#28](https://github.com/crystal-lang/crystal-sqlite3/pull/28), thanks @rufusroflpunch) + +## v0.8.2 (2017-03-21) diff --git a/lib/sqlite3/LICENSE b/lib/sqlite3/LICENSE new file mode 100644 index 00000000..ab07ebce --- /dev/null +++ b/lib/sqlite3/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 Brian J. Cardiff + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/lib/sqlite3/README.md b/lib/sqlite3/README.md new file mode 100644 index 00000000..2ba4a429 --- /dev/null +++ b/lib/sqlite3/README.md @@ -0,0 +1,50 @@ +# crystal-sqlite3 [![Build Status](https://travis-ci.org/crystal-lang/crystal-sqlite3.svg?branch=master)](https://travis-ci.org/crystal-lang/crystal-sqlite3) + +SQLite3 bindings for [Crystal](http://crystal-lang.org/). + +Check [crystal-db](https://github.com/crystal-lang/crystal-db) for general db driver documentation. crystal-sqlite3 driver is registered under `sqlite3://` uri. + +## Installation + +Add this to your application's `shard.yml`: + +```yml +dependencies: + sqlite3: + github: crystal-lang/crystal-sqlite3 +``` + +### Usage + +```crystal +require "sqlite3" + +DB.open "sqlite3://./data.db" do |db| + db.exec "create table contacts (name text, age integer)" + db.exec "insert into contacts values (?, ?)", "John Doe", 30 + + args = [] of DB::Any + args << "Sarah" + args << 33 + db.exec "insert into contacts values (?, ?)", args + + puts "max age:" + puts db.scalar "select max(age) from contacts" # => 33 + + puts "contacts:" + db.query "select name, age from contacts order by age desc" do |rs| + puts "#{rs.column_name(0)} (#{rs.column_name(1)})" + # => name (age) + rs.each do + puts "#{rs.read(String)} (#{rs.read(Int32)})" + # => Sarah (33) + # => John Doe (30) + end + end +end +``` + +### DB::Any + +* `Time` is implemented as `TEXT` column using `SQLite3::DATE_FORMAT` format. +* `Bool` is implemented as `INT` column mapping `0`/`1` values. diff --git a/lib/sqlite3/samples/memory.cr b/lib/sqlite3/samples/memory.cr new file mode 100644 index 00000000..87e02cc6 --- /dev/null +++ b/lib/sqlite3/samples/memory.cr @@ -0,0 +1,26 @@ +require "db" +require "../src/sqlite3" + +DB.open "sqlite3://%3Amemory%3A" do |db| + db.exec "create table contacts (name text, age integer)" + db.exec "insert into contacts values (?, ?)", "John Doe", 30 + + args = [] of DB::Any + args << "Sarah" + args << 33 + db.exec "insert into contacts values (?, ?)", args + + puts "max age:" + puts db.scalar "select max(age) from contacts" # => 33 + + puts "contacts:" + db.query "select name, age from contacts order by age desc" do |rs| + puts "#{rs.column_name(0)} (#{rs.column_name(1)})" + # => name (age) + rs.each do + puts "#{rs.read(String)} (#{rs.read(Int32)})" + # => Sarah (33) + # => John Doe (30) + end + end +end diff --git a/lib/sqlite3/shard.yml b/lib/sqlite3/shard.yml new file mode 100644 index 00000000..a9852e25 --- /dev/null +++ b/lib/sqlite3/shard.yml @@ -0,0 +1,15 @@ +name: sqlite3 +version: 0.13.0 + +dependencies: + db: + github: crystal-lang/crystal-db + version: ~> 0.6.0 + +authors: + - Ary Borenszweig + - Brian J. Cardiff + +crystal: 0.28.0 + +license: MIT diff --git a/lib/sqlite3/spec/connection_spec.cr b/lib/sqlite3/spec/connection_spec.cr new file mode 100644 index 00000000..be3bf0ce --- /dev/null +++ b/lib/sqlite3/spec/connection_spec.cr @@ -0,0 +1,71 @@ +require "./spec_helper" + +private def dump(source, target) + source.using_connection do |conn| + conn = conn.as(SQLite3::Connection) + target.using_connection do |backup_conn| + backup_conn = backup_conn.as(SQLite3::Connection) + conn.dump(backup_conn) + end + end +end + +describe Connection do + it "opens a database and then backs it up to another db" do + with_db do |db| + with_db("./test2.db") do |backup_db| + db.exec "create table person (name text, age integer)" + db.exec "insert into person values (\"foo\", 10)" + + dump db, backup_db + + backup_name = backup_db.scalar "select name from person" + backup_age = backup_db.scalar "select age from person" + source_name = db.scalar "select name from person" + source_age = db.scalar "select age from person" + + {backup_name, backup_age}.should eq({source_name, source_age}) + end + end + end + + it "opens a database, inserts records, dumps to an in-memory db, insers some more, then dumps to the source" do + with_db do |db| + with_mem_db do |in_memory_db| + db.exec "create table person (name text, age integer)" + db.exec "insert into person values (\"foo\", 10)" + dump db, in_memory_db + + in_memory_db.scalar("select count(*) from person").should eq(1) + in_memory_db.exec "insert into person values (\"bar\", 22)" + dump in_memory_db, db + + db.scalar("select count(*) from person").should eq(2) + end + end + end + + it "opens a database, inserts records (>1024K), and dumps to an in-memory db" do + with_db do |db| + with_mem_db do |in_memory_db| + db.exec "create table person (name text, age integer)" + db.transaction do |tx| + 100_000.times { tx.connection.exec "insert into person values (\"foo\", 10)" } + end + dump db, in_memory_db + in_memory_db.scalar("select count(*) from person").should eq(100_000) + end + end + end + + it "opens a connection without the pool" do + with_cnn do |cnn| + cnn.should be_a(SQLite3::Connection) + + cnn.exec "create table person (name text, age integer)" + cnn.exec "insert into person values (\"foo\", 10)" + + cnn.scalar("select count(*) from person").should eq(1) + end + end +end diff --git a/lib/sqlite3/spec/db_spec.cr b/lib/sqlite3/spec/db_spec.cr new file mode 100644 index 00000000..3a2fa1f0 --- /dev/null +++ b/lib/sqlite3/spec/db_spec.cr @@ -0,0 +1,131 @@ +require "./spec_helper" +require "db/spec" + +private class NotSupportedType +end + +private def cast_if_blob(expr, sql_type) + case sql_type + when "blob" + "cast(#{expr} as blob)" + else + expr + end +end + +DB::DriverSpecs(DB::Any).run do + support_unprepared false + + before do + File.delete(DB_FILENAME) if File.exists?(DB_FILENAME) + end + after do + File.delete(DB_FILENAME) if File.exists?(DB_FILENAME) + end + + connection_string "sqlite3:#{DB_FILENAME}" + # ? can use many ... (:memory:) + + sample_value true, "int", "1", type_safe_value: false + sample_value false, "int", "0", type_safe_value: false + sample_value 2, "int", "2", type_safe_value: false + sample_value 1_i64, "int", "1" + sample_value "hello", "text", "'hello'" + sample_value 1.5_f32, "float", "1.5", type_safe_value: false + sample_value 1.5, "float", "1.5" + sample_value Time.utc(2016, 2, 15), "text", "'2016-02-15 00:00:00.000'", type_safe_value: false + sample_value Time.utc(2016, 2, 15, 10, 15, 30), "text", "'2016-02-15 10:15:30.000'", type_safe_value: false + sample_value Time.local(2016, 2, 15, 7, 15, 30, location: Time::Location.fixed("fixed", -3*3600)), "text", "'2016-02-15 10:15:30.000'", type_safe_value: false + + ary = UInt8[0x53, 0x51, 0x4C, 0x69, 0x74, 0x65] + sample_value Bytes.new(ary.to_unsafe, ary.size), "blob", "X'53514C697465'" # , type_safe_value: false + + binding_syntax do |index| + "?" + end + + create_table_1column_syntax do |table_name, col1| + "create table #{table_name} (#{col1.name} #{col1.sql_type} #{col1.null ? "NULL" : "NOT NULL"})" + end + + create_table_2columns_syntax do |table_name, col1, col2| + "create table #{table_name} (#{col1.name} #{col1.sql_type} #{col1.null ? "NULL" : "NOT NULL"}, #{col2.name} #{col2.sql_type} #{col2.null ? "NULL" : "NOT NULL"})" + end + + select_1column_syntax do |table_name, col1| + "select #{cast_if_blob(col1.name, col1.sql_type)} from #{table_name}" + end + + select_2columns_syntax do |table_name, col1, col2| + "select #{cast_if_blob(col1.name, col1.sql_type)}, #{cast_if_blob(col2.name, col2.sql_type)} from #{table_name}" + end + + select_count_syntax do |table_name| + "select count(*) from #{table_name}" + end + + select_scalar_syntax do |expression, sql_type| + "select #{cast_if_blob(expression, sql_type)}" + end + + insert_1column_syntax do |table_name, col, expression| + "insert into #{table_name} (#{col.name}) values (#{expression})" + end + + insert_2columns_syntax do |table_name, col1, expr1, col2, expr2| + "insert into #{table_name} (#{col1.name}, #{col2.name}) values (#{expr1}, #{expr2})" + end + + drop_table_if_exists_syntax do |table_name| + "drop table if exists #{table_name}" + end + + it "gets last insert row id", prepared: :both do |db| + db.exec "create table person (name text, age integer)" + db.exec %(insert into person values ("foo", 10)) + res = db.exec %(insert into person values ("foo", 10)) + res.last_insert_id.should eq(2) + res.rows_affected.should eq(1) + end + + # TODO timestamp support + + it "raises on unsupported param types" do |db| + expect_raises Exception, "SQLite3::Statement does not support NotSupportedType params" do + db.query "select ?", NotSupportedType.new + end + # TODO raising exception does not close the connection and pool is exhausted + end + + it "ensures statements are closed" do |db| + db.exec %(create table if not exists a (i int not null, str text not null);) + db.exec %(insert into a (i, str) values (23, "bai bai");) + + 2.times do |i| + DB.open db.uri do |db| + begin + db.query("SELECT i, str FROM a WHERE i = ?", 23) do |rs| + rs.move_next + break + end + rescue e : SQLite3::Exception + fail("Expected no exception, but got \"#{e.message}\"") + end + + begin + db.exec("UPDATE a SET i = ? WHERE i = ?", 23, 23) + rescue e : SQLite3::Exception + fail("Expected no exception, but got \"#{e.message}\"") + end + end + end + end + + it "handles single-step pragma statements" do |db| + db.exec %(PRAGMA synchronous = OFF) + end + + it "handles multi-step pragma statements" do |db| + db.exec %(PRAGMA journal_mode = memory) + end +end diff --git a/lib/sqlite3/spec/driver_spec.cr b/lib/sqlite3/spec/driver_spec.cr new file mode 100644 index 00000000..fbd64b55 --- /dev/null +++ b/lib/sqlite3/spec/driver_spec.cr @@ -0,0 +1,34 @@ +require "./spec_helper" + +def assert_filename(uri, filename) + SQLite3::Connection.filename(URI.parse(uri)).should eq(filename) +end + +describe Driver do + it "should register sqlite3 name" do + DB.driver_class("sqlite3").should eq(SQLite3::Driver) + end + + it "should get filename from uri" do + assert_filename("sqlite3:%3Amemory%3A", ":memory:") + assert_filename("sqlite3://%3Amemory%3A", ":memory:") + + assert_filename("sqlite3:./file.db", "./file.db") + assert_filename("sqlite3://./file.db", "./file.db") + + assert_filename("sqlite3:/path/to/file.db", "/path/to/file.db") + assert_filename("sqlite3:///path/to/file.db", "/path/to/file.db") + + assert_filename("sqlite3:./file.db?max_pool_size=5", "./file.db") + assert_filename("sqlite3:/path/to/file.db?max_pool_size=5", "/path/to/file.db") + assert_filename("sqlite3://./file.db?max_pool_size=5", "./file.db") + assert_filename("sqlite3:///path/to/file.db?max_pool_size=5", "/path/to/file.db") + end + + it "should use database option as file to open" do + with_db do |db| + db.driver.should be_a(SQLite3::Driver) + File.exists?(DB_FILENAME).should be_true + end + end +end diff --git a/lib/sqlite3/spec/pool_spec.cr b/lib/sqlite3/spec/pool_spec.cr new file mode 100644 index 00000000..4971b94b --- /dev/null +++ b/lib/sqlite3/spec/pool_spec.cr @@ -0,0 +1,32 @@ +require "./spec_helper" + +describe DB::Pool do + it "should write from multiple connections" do + channel = Channel(Nil).new + fibers = 5 + max_n = 50 + with_db "#{DB_FILENAME}?max_pool_size=#{fibers}" do |db| + db.exec "create table numbers (n integer, fiber integer)" + + fibers.times do |f| + spawn do + (1..max_n).each do |n| + db.exec "insert into numbers (n, fiber) values (?, ?)", n, f + sleep 0.01 + end + channel.send nil + end + end + + fibers.times { channel.receive } + + # all numbers were inserted + s = fibers * max_n * (max_n + 1) // 2 + db.scalar("select sum(n) from numbers").should eq(s) + + # numbers were not inserted one fiber at a time + rows = db.query_all "select n, fiber from numbers", as: {Int32, Int32} + rows.map(&.[1]).should_not eq(rows.map(&.[1]).sort) + end + end +end diff --git a/lib/sqlite3/spec/spec_helper.cr b/lib/sqlite3/spec/spec_helper.cr new file mode 100644 index 00000000..d83e8450 --- /dev/null +++ b/lib/sqlite3/spec/spec_helper.cr @@ -0,0 +1,33 @@ +require "spec" +require "../src/sqlite3" + +include SQLite3 + +DB_FILENAME = "./test.db" + +def with_db(&block : DB::Database ->) + File.delete(DB_FILENAME) rescue nil + DB.open "sqlite3:#{DB_FILENAME}", &block +ensure + File.delete(DB_FILENAME) +end + +def with_cnn(&block : DB::Connection ->) + File.delete(DB_FILENAME) rescue nil + DB.connect "sqlite3:#{DB_FILENAME}", &block +ensure + File.delete(DB_FILENAME) +end + +def with_db(config, &block : DB::Database ->) + uri = "sqlite3:#{config}" + filename = SQLite3::Connection.filename(URI.parse(uri)) + File.delete(filename) rescue nil + DB.open uri, &block +ensure + File.delete(filename) if filename +end + +def with_mem_db(&block : DB::Database ->) + DB.open "sqlite3://%3Amemory%3A", &block +end diff --git a/lib/sqlite3/src/sqlite3.cr b/lib/sqlite3/src/sqlite3.cr new file mode 100644 index 00000000..77ff98fc --- /dev/null +++ b/lib/sqlite3/src/sqlite3.cr @@ -0,0 +1,9 @@ +require "db" +require "./sqlite3/**" + +module SQLite3 + DATE_FORMAT = "%F %H:%M:%S.%L" + + # :nodoc: + TIME_ZONE = Time::Location::UTC +end diff --git a/lib/sqlite3/src/sqlite3/connection.cr b/lib/sqlite3/src/sqlite3/connection.cr new file mode 100644 index 00000000..5079c3c9 --- /dev/null +++ b/lib/sqlite3/src/sqlite3/connection.cr @@ -0,0 +1,93 @@ +class SQLite3::Connection < DB::Connection + def initialize(database) + super + filename = self.class.filename(database.uri) + # TODO maybe enable Flag::URI to parse query string in the uri as additional flags + check LibSQLite3.open_v2(filename, out @db, (Flag::READWRITE | Flag::CREATE), nil) + rescue + raise DB::ConnectionRefused.new + end + + def self.filename(uri : URI) + {% if compare_versions(Crystal::VERSION, "0.30.0-0") >= 0 %} + URI.decode_www_form((uri.host || "") + uri.path) + {% else %} + URI.unescape((uri.host || "") + uri.path) + {% end %} + end + + def build_prepared_statement(query) : Statement + Statement.new(self, query) + end + + def build_unprepared_statement(query) : Statement + # sqlite3 does not support unprepared statement. + # All statements once prepared should be released + # when unneeded. Unprepared statement are not aim + # to leave state in the connection. Mimicking them + # with prepared statement would be wrong with + # respect connection resources. + raise DB::Error.new("SQLite3 driver does not support unprepared statements") + end + + def do_close + super + check LibSQLite3.close(self) + end + + # :nodoc: + def perform_begin_transaction + self.prepared.exec "BEGIN" + end + + # :nodoc: + def perform_commit_transaction + self.prepared.exec "COMMIT" + end + + # :nodoc: + def perform_rollback_transaction + self.prepared.exec "ROLLBACK" + end + + # :nodoc: + def perform_create_savepoint(name) + self.prepared.exec "SAVEPOINT #{name}" + end + + # :nodoc: + def perform_release_savepoint(name) + self.prepared.exec "RELEASE SAVEPOINT #{name}" + end + + # :nodoc: + def perform_rollback_savepoint(name) + self.prepared.exec "ROLLBACK TO #{name}" + end + + # Dump the database to another SQLite3 database. This can be used for backing up a SQLite3 Database + # to disk or the opposite + def dump(to : SQLite3::Connection) + backup_item = LibSQLite3.backup_init(to.@db, "main", @db, "main") + if backup_item.null? + raise Exception.new(to.@db) + end + code = LibSQLite3.backup_step(backup_item, -1) + + if code != LibSQLite3::Code::DONE + raise Exception.new(to.@db) + end + code = LibSQLite3.backup_finish(backup_item) + if code != LibSQLite3::Code::OKAY + raise Exception.new(to.@db) + end + end + + def to_unsafe + @db + end + + private def check(code) + raise Exception.new(self) unless code == 0 + end +end diff --git a/lib/sqlite3/src/sqlite3/driver.cr b/lib/sqlite3/src/sqlite3/driver.cr new file mode 100644 index 00000000..d43a5126 --- /dev/null +++ b/lib/sqlite3/src/sqlite3/driver.cr @@ -0,0 +1,7 @@ +class SQLite3::Driver < DB::Driver + def build_connection(context : DB::ConnectionContext) : SQLite3::Connection + SQLite3::Connection.new(context) + end +end + +DB.register_driver "sqlite3", SQLite3::Driver diff --git a/lib/sqlite3/src/sqlite3/exception.cr b/lib/sqlite3/src/sqlite3/exception.cr new file mode 100644 index 00000000..59401443 --- /dev/null +++ b/lib/sqlite3/src/sqlite3/exception.cr @@ -0,0 +1,10 @@ +# Exception thrown on invalid SQLite3 operations. +class SQLite3::Exception < ::Exception + # The internal code associated with the failure. + getter code + + def initialize(db) + super(String.new(LibSQLite3.errmsg(db))) + @code = LibSQLite3.errcode(db) + end +end diff --git a/lib/sqlite3/src/sqlite3/flags.cr b/lib/sqlite3/src/sqlite3/flags.cr new file mode 100644 index 00000000..ccb84569 --- /dev/null +++ b/lib/sqlite3/src/sqlite3/flags.cr @@ -0,0 +1,30 @@ +@[Flags] +enum SQLite3::Flag + READONLY = 0x00000001 # Ok for sqlite3_open_v2() + READWRITE = 0x00000002 # Ok for sqlite3_open_v2() + CREATE = 0x00000004 # Ok for sqlite3_open_v2() + DELETEONCLOSE = 0x00000008 # VFS only + EXCLUSIVE = 0x00000010 # VFS only + AUTOPROXY = 0x00000020 # VFS only + URI = 0x00000040 # Ok for sqlite3_open_v2() + MEMORY = 0x00000080 # Ok for sqlite3_open_v2() + MAIN_DB = 0x00000100 # VFS only + TEMP_DB = 0x00000200 # VFS only + TRANSIENT_DB = 0x00000400 # VFS only + MAIN_JOURNAL = 0x00000800 # VFS only + TEMP_JOURNAL = 0x00001000 # VFS only + SUBJOURNAL = 0x00002000 # VFS only + MASTER_JOURNAL = 0x00004000 # VFS only + NOMUTEX = 0x00008000 # Ok for sqlite3_open_v2() + FULLMUTEX = 0x00010000 # Ok for sqlite3_open_v2() + SHAREDCACHE = 0x00020000 # Ok for sqlite3_open_v2() + PRIVATECACHE = 0x00040000 # Ok for sqlite3_open_v2() + WAL = 0x00080000 # VFS only +end + +module SQLite3 + # Same as doing SQLite3::Flag.flag(*values) + macro flags(*values) + ::SQLite3::Flag.flags({{*values}}) + end +end diff --git a/lib/sqlite3/src/sqlite3/lib_sqlite3.cr b/lib/sqlite3/src/sqlite3/lib_sqlite3.cr new file mode 100644 index 00000000..d99f67c0 --- /dev/null +++ b/lib/sqlite3/src/sqlite3/lib_sqlite3.cr @@ -0,0 +1,111 @@ +require "./type" + +@[Link("sqlite3")] +lib LibSQLite3 + type SQLite3 = Void* + type Statement = Void* + type SQLite3Backup = Void* + + enum Code + # Successful result + OKAY = 0 + # Generic error + ERROR = 1 + # Internal logic error in SQLite + INTERNAL = 2 + # Access permission denied + PERM = 3 + # Callback routine requested an abort + ABORT = 4 + # The database file is locked + BUSY = 5 + # A table in the database is locked + LOCKED = 6 + # A malloc() failed + NOMEM = 7 + # Attempt to write a readonly database + READONLY = 8 + # Operation terminated by sqlite3_interrupt() + INTERRUPT = 9 + # Some kind of disk I/O error occurred + IOERR = 10 + # The database disk image is malformed + CORRUPT = 11 + # Unknown opcode in sqlite3_file_control() + NOTFOUND = 12 + # Insertion failed because database is full + FULL = 13 + # Unable to open the database file + CANTOPEN = 14 + # Database lock protocol error + PROTOCOL = 15 + # Internal use only + EMPTY = 16 + # The database schema changed + SCHEMA = 17 + # String or BLOB exceeds size limit + TOOBIG = 18 + # Abort due to constraint violation + CONSTRAINT = 19 + # Data type mismatch + MISMATCH = 20 + # Library used incorrectly + MISUSE = 21 + # Uses OS features not supported on host + NOLFS = 22 + # Authorization denied + AUTH = 23 + # Not used + FORMAT = 24 + # 2nd parameter to sqlite3_bind out of range + RANGE = 25 + # File opened that is not a database file + NOTADB = 26 + # Notifications from sqlite3_log() + NOTICE = 27 + # Warnings from sqlite3_log() + WARNING = 28 + # sqlite3_step() has another row ready + ROW = 100 + # sqlite3_step() has finished executing + DONE = 101 + end + + alias Callback = (Void*, Int32, UInt8**, UInt8**) -> Int32 + + fun open_v2 = sqlite3_open_v2(filename : UInt8*, db : SQLite3*, flags : ::SQLite3::Flag, zVfs : UInt8*) : Int32 + + fun errcode = sqlite3_errcode(SQLite3) : Int32 + fun errmsg = sqlite3_errmsg(SQLite3) : UInt8* + + fun backup_init = sqlite3_backup_init(SQLite3, UInt8*, SQLite3, UInt8*) : SQLite3Backup + fun backup_step = sqlite3_backup_step(SQLite3Backup, Int32) : Code + fun backup_finish = sqlite3_backup_finish(SQLite3Backup) : Code + + fun prepare_v2 = sqlite3_prepare_v2(db : SQLite3, zSql : UInt8*, nByte : Int32, ppStmt : Statement*, pzTail : UInt8**) : Int32 + fun step = sqlite3_step(stmt : Statement) : Int32 + fun column_count = sqlite3_column_count(stmt : Statement) : Int32 + fun column_type = sqlite3_column_type(stmt : Statement, iCol : Int32) : ::SQLite3::Type + fun column_int64 = sqlite3_column_int64(stmt : Statement, iCol : Int32) : Int64 + fun column_double = sqlite3_column_double(stmt : Statement, iCol : Int32) : Float64 + fun column_text = sqlite3_column_text(stmt : Statement, iCol : Int32) : UInt8* + fun column_bytes = sqlite3_column_bytes(stmt : Statement, iCol : Int32) : Int32 + fun column_blob = sqlite3_column_blob(stmt : Statement, iCol : Int32) : UInt8* + + fun bind_int = sqlite3_bind_int(stmt : Statement, idx : Int32, value : Int32) : Int32 + fun bind_int64 = sqlite3_bind_int64(stmt : Statement, idx : Int32, value : Int64) : Int32 + fun bind_text = sqlite3_bind_text(stmt : Statement, idx : Int32, value : UInt8*, bytes : Int32, destructor : Void* ->) : Int32 + fun bind_blob = sqlite3_bind_blob(stmt : Statement, idx : Int32, value : UInt8*, bytes : Int32, destructor : Void* ->) : Int32 + fun bind_null = sqlite3_bind_null(stmt : Statement, idx : Int32) : Int32 + fun bind_double = sqlite3_bind_double(stmt : Statement, idx : Int32, value : Float64) : Int32 + + fun bind_parameter_index = sqlite3_bind_parameter_index(stmt : Statement, name : UInt8*) : Int32 + fun reset = sqlite3_reset(stmt : Statement) : Int32 + fun column_name = sqlite3_column_name(stmt : Statement, idx : Int32) : UInt8* + fun last_insert_rowid = sqlite3_last_insert_rowid(db : SQLite3) : Int64 + fun changes = sqlite3_changes(db : SQLite3) : Int32 + + fun finalize = sqlite3_finalize(stmt : Statement) : Int32 + fun close_v2 = sqlite3_close_v2(SQLite3) : Int32 + fun close = sqlite3_close(SQLite3) : Int32 +end diff --git a/lib/sqlite3/src/sqlite3/result_set.cr b/lib/sqlite3/src/sqlite3/result_set.cr new file mode 100644 index 00000000..38e669f2 --- /dev/null +++ b/lib/sqlite3/src/sqlite3/result_set.cr @@ -0,0 +1,108 @@ +class SQLite3::ResultSet < DB::ResultSet + @column_index = 0 + + protected def do_close + super + LibSQLite3.reset(self) + end + + # Advances to the next row. Returns `true` if there's a next row, + # `false` otherwise. Must be called at least once to advance to the first + # row. + def move_next : Bool + @column_index = 0 + + case step + when LibSQLite3::Code::ROW + true + when LibSQLite3::Code::DONE + false + else + raise Exception.new(sqlite3_statement.sqlite3_connection) + end + end + + def read + col = @column_index + value = + case LibSQLite3.column_type(self, col) + when Type::INTEGER + LibSQLite3.column_int64(self, col) + when Type::FLOAT + LibSQLite3.column_double(self, col) + when Type::BLOB + blob = LibSQLite3.column_blob(self, col) + bytes = LibSQLite3.column_bytes(self, col) + ptr = Pointer(UInt8).malloc(bytes) + ptr.copy_from(blob, bytes) + Bytes.new(ptr, bytes) + when Type::TEXT + String.new(LibSQLite3.column_text(self, col)) + when Type::NULL + nil + else + raise Exception.new(sqlite3_statement.sqlite3_connection) + end + @column_index += 1 + value + end + + def read(t : Int32.class) : Int32 + read(Int64).to_i32 + end + + def read(type : Int32?.class) : Int32? + read(Int64?).try &.to_i32 + end + + def read(t : Float32.class) : Float32 + read(Float64).to_f32 + end + + def read(type : Float32?.class) : Float32? + read(Float64?).try &.to_f32 + end + + def read(t : Time.class) : Time + Time.parse read(String), SQLite3::DATE_FORMAT, location: SQLite3::TIME_ZONE + end + + def read(t : Time?.class) : Time? + read(String?).try { |v| Time.parse(v, SQLite3::DATE_FORMAT, location: SQLite3::TIME_ZONE) } + end + + def read(t : Bool.class) : Bool + read(Int64) != 0 + end + + def read(t : Bool?.class) : Bool? + read(Int64?).try &.!=(0) + end + + def column_count : Int32 + LibSQLite3.column_count(self) + end + + def column_name(index) : String + String.new LibSQLite3.column_name(self, index) + end + + def to_unsafe + sqlite3_statement.to_unsafe + end + + # :nodoc: + private def step + LibSQLite3::Code.new LibSQLite3.step(sqlite3_statement) + end + + protected def sqlite3_statement + @statement.as(Statement) + end + + private def moving_column + res = yield @column_index + @column_index += 1 + res + end +end diff --git a/lib/sqlite3/src/sqlite3/statement.cr b/lib/sqlite3/src/sqlite3/statement.cr new file mode 100644 index 00000000..c5dc619a --- /dev/null +++ b/lib/sqlite3/src/sqlite3/statement.cr @@ -0,0 +1,91 @@ +class SQLite3::Statement < DB::Statement + def initialize(connection, sql) + super(connection) + check LibSQLite3.prepare_v2(sqlite3_connection, sql, sql.bytesize + 1, out @stmt, nil) + end + + protected def perform_query(args : Enumerable) : DB::ResultSet + LibSQLite3.reset(self) + args.each_with_index(1) do |arg, index| + bind_arg(index, arg) + end + ResultSet.new(self) + end + + protected def perform_exec(args : Enumerable) : DB::ExecResult + LibSQLite3.reset(self.to_unsafe) + args.each_with_index(1) do |arg, index| + bind_arg(index, arg) + end + + # exec + step = uninitialized LibSQLite3::Code + loop do + step = LibSQLite3::Code.new LibSQLite3.step(self) + break unless step == LibSQLite3::Code::ROW + end + raise Exception.new(sqlite3_connection) unless step == LibSQLite3::Code::DONE + + rows_affected = LibSQLite3.changes(sqlite3_connection).to_i64 + last_id = LibSQLite3.last_insert_rowid(sqlite3_connection) + + DB::ExecResult.new rows_affected, last_id + end + + protected def do_close + super + check LibSQLite3.finalize(self) + end + + private def bind_arg(index, value : Nil) + check LibSQLite3.bind_null(self, index) + end + + private def bind_arg(index, value : Bool) + check LibSQLite3.bind_int(self, index, value ? 1 : 0) + end + + private def bind_arg(index, value : Int32) + check LibSQLite3.bind_int(self, index, value) + end + + private def bind_arg(index, value : Int64) + check LibSQLite3.bind_int64(self, index, value) + end + + private def bind_arg(index, value : Float32) + check LibSQLite3.bind_double(self, index, value.to_f64) + end + + private def bind_arg(index, value : Float64) + check LibSQLite3.bind_double(self, index, value) + end + + private def bind_arg(index, value : String) + check LibSQLite3.bind_text(self, index, value, value.bytesize, nil) + end + + private def bind_arg(index, value : Bytes) + check LibSQLite3.bind_blob(self, index, value, value.size, nil) + end + + private def bind_arg(index, value : Time) + bind_arg(index, value.in(SQLite3::TIME_ZONE).to_s(SQLite3::DATE_FORMAT)) + end + + private def bind_arg(index, value) + raise "#{self.class} does not support #{value.class} params" + end + + private def check(code) + raise Exception.new(sqlite3_connection) unless code == 0 + end + + protected def sqlite3_connection + @connection.as(Connection) + end + + def to_unsafe + @stmt + end +end diff --git a/lib/sqlite3/src/sqlite3/type.cr b/lib/sqlite3/src/sqlite3/type.cr new file mode 100644 index 00000000..a4abe0e4 --- /dev/null +++ b/lib/sqlite3/src/sqlite3/type.cr @@ -0,0 +1,8 @@ +# Each of the possible types of an SQLite3 column. +enum SQLite3::Type + INTEGER = 1 + FLOAT = 2 + BLOB = 4 + NULL = 5 + TEXT = 3 +end diff --git a/lib/sqlite3/src/sqlite3/version.cr b/lib/sqlite3/src/sqlite3/version.cr new file mode 100644 index 00000000..1aaa48d0 --- /dev/null +++ b/lib/sqlite3/src/sqlite3/version.cr @@ -0,0 +1,3 @@ +module SQLite3 + VERSION = "0.13.0" +end diff --git a/shard.lock b/shard.lock new file mode 100644 index 00000000..c9704a1b --- /dev/null +++ b/shard.lock @@ -0,0 +1,30 @@ +version: 1.0 +shards: + db: + github: crystal-lang/crystal-db + version: 0.6.0 + + exception_page: + github: crystal-loot/exception_page + version: 0.1.2 + + kemal: + github: kemalcr/kemal + version: 0.26.0 + + kilt: + github: jeromegn/kilt + version: 0.4.0 + + pg: + github: will/crystal-pg + version: 0.18.1 + + radix: + github: luislavena/radix + version: 0.3.9 + + sqlite3: + github: crystal-lang/crystal-sqlite3 + version: 0.13.0 +